diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 76f6d7aeca0d..77ee313687fc 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -5,11 +5,11 @@ import sys import zipfile -# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 500 MiB # Note that we have 800 MiB quota, please use it wisely. # See https://github.com/pypi/support/issues/6326 . # Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450)) +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 500)) def print_top_10_largest_files(zip_file): diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml new file mode 100644 index 000000000000..56ec933c9cc0 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml @@ -0,0 +1,12 @@ +# For vllm script, with -t option (tensor parallel size). +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 +model_name: "HandH1998/QQQ-Llama-3-8b-g128" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.419 + - name: "exact_match,flexible-extract" + value: 0.416 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8-MM.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8-MM.yaml new file mode 100644 index 000000000000..ccb4f84201b7 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8-MM.yaml @@ -0,0 +1,12 @@ +# For hf script, without -t option (tensor parallel size). +# bash .buildkite/lm-eval-harness/run-lm-eval-chartqa-vllm-vlm-baseline.sh -m meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 -l 100 -t 8 +model_name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8" +backend: "vllm-vlm" +tasks: +- name: "chartqa" + metrics: + - name: "relaxed_accuracy,none" + # TODO(zhewenl): model card is 0.90, but the actual score is 0.80. + value: 0.80 +limit: 100 +num_fewshot: 0 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml new file mode 100644 index 000000000000..46f1a9fbf6ff --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml @@ -0,0 +1,10 @@ +# For hf script, without -t option (tensor parallel size). 
+# bash .buildkite/lm-eval-harness/run-lm-eval-mmlupro-vllm-baseline.sh -m meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 -l 250 -t 8 -f 5 +model_name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8" +tasks: +- name: "mmlu_pro" + metrics: + - name: "exact_match,custom-extract" + value: 0.80 +limit: 250 # will run on 250 * 14 subjects = 3500 samples +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml index a2f235f48581..aa4fb9fa03d6 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -1,4 +1,5 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1 +# For vllm script, with -t option (tensor parallel size) +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -l 1319 -t 1 model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" tasks: - name: "gsm8k" diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-7B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-7B-Instruct.yaml new file mode 100644 index 000000000000..5f3c31743e75 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-7B-Instruct.yaml @@ -0,0 +1,12 @@ +# For vllm script, with -t option (tensor parallel size). +# bash .buildkite/lm-eval-harness/run-lm-eval-chartqa-vllm-vlm-baseline.sh -m Qwen/Qwen2.5-VL-7B-Instruct -l 2500 -t 1 + +model_name: "Qwen/Qwen2.5-VL-7B-Instruct" +backend: "vllm-vlm" +tasks: +- name: "chartqa" + metrics: + - name: "relaxed_accuracy,none" + value: 0.855 +limit: 2500 +num_fewshot: 0 diff --git a/.buildkite/lm-eval-harness/configs/models-large-h100.txt b/.buildkite/lm-eval-harness/configs/models-large-h100.txt new file mode 100644 index 000000000000..4fb0b84bc4d8 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-large-h100.txt @@ -0,0 +1 @@ +Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml diff --git a/.buildkite/lm-eval-harness/configs/models-mm-large-h100.txt b/.buildkite/lm-eval-harness/configs/models-mm-large-h100.txt new file mode 100644 index 000000000000..91e22b6459c1 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-mm-large-h100.txt @@ -0,0 +1 @@ +Meta-Llama-4-Maverick-17B-128E-Instruct-FP8-MM.yaml diff --git a/.buildkite/lm-eval-harness/configs/models-mm-small.txt b/.buildkite/lm-eval-harness/configs/models-mm-small.txt new file mode 100644 index 000000000000..1097d220245f --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-mm-small.txt @@ -0,0 +1 @@ +Qwen2.5-VL-7B-Instruct.yaml \ No newline at end of file diff --git a/.buildkite/lm-eval-harness/run-lm-eval-chartqa-vllm-vlm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-chartqa-vllm-vlm-baseline.sh new file mode 100755 index 000000000000..c8db951381b0 --- /dev/null +++ b/.buildkite/lm-eval-harness/run-lm-eval-chartqa-vllm-vlm-baseline.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# We can use this script to compute baseline accuracy on chartqa for vllm. +# +# Make sure you have lm-eval-harness installed: +# pip install lm-eval==0.4.9 + +usage() { + echo`` + echo "Runs lm eval harness on ChartQA using multimodal vllm." + echo "This pathway is intended to be used to create baselines for " + echo "our correctness tests in vllm's CI." 
+ echo + echo "usage: ${0} " + echo + echo " -m - huggingface stub or local directory of the model" + echo " -l - limit number of samples to run" + echo " -t - tensor parallel size to run at" + echo +} + +while getopts "m:l:t:" OPT; do + case ${OPT} in + m ) + MODEL="$OPTARG" + ;; + l ) + LIMIT="$OPTARG" + ;; + t ) + TP_SIZE="$OPTARG" + ;; + \? ) + usage + exit 1 + ;; + esac +done + +lm_eval --model vllm-vlm \ + --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE" \ + --tasks chartqa \ + --batch_size auto \ + --apply_chat_template \ + --limit $LIMIT diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh old mode 100644 new mode 100755 diff --git a/.buildkite/lm-eval-harness/run-lm-eval-mmlupro-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-mmlupro-vllm-baseline.sh new file mode 100644 index 000000000000..d85a1721db9a --- /dev/null +++ b/.buildkite/lm-eval-harness/run-lm-eval-mmlupro-vllm-baseline.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# We can use this script to compute baseline accuracy on MMLUPRO for vllm. +# We use this for fp8, which HF does not support. +# +# Make sure you have lm-eval-harness installed: +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] + +usage() { + echo`` + echo "Runs lm eval harness on MMLU Pro using huggingface transformers." + echo "This pathway is intended to be used to create baselines for " + echo "our automated nm-test-accuracy workflow" + echo + echo "usage: ${0} " + echo + echo " -m - huggingface stub or local directory of the model" + echo " -l - limit number of samples to run" + echo " -f - number of fewshot samples to use" + echo " -t - tensor parallel size to run at" + echo +} + +while getopts "m:b:l:f:t:" OPT; do + case ${OPT} in + m ) + MODEL="$OPTARG" + ;; + b ) + BATCH_SIZE="$OPTARG" + ;; + l ) + LIMIT="$OPTARG" + ;; + f ) + FEWSHOT="$OPTARG" + ;; + t ) + TP_SIZE="$OPTARG" + ;; + \? ) + usage + exit 1 + ;; + esac +done + +lm_eval --model vllm \ + --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \ + --tasks mmlu_pro --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ + --batch_size auto diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index ceea01166b7f..f10de82b1d8e 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -19,21 +19,27 @@ def launch_lm_eval(eval_config, tp_size): trust_remote_code = eval_config.get("trust_remote_code", False) max_model_len = eval_config.get("max_model_len", 4096) + batch_size = eval_config.get("batch_size", "auto") + backend = eval_config.get("backend", "vllm") model_args = ( f"pretrained={eval_config['model_name']}," f"tensor_parallel_size={tp_size}," f"enforce_eager=true," f"add_bos_token=true," f"trust_remote_code={trust_remote_code}," - f"max_model_len={max_model_len}" + f"max_model_len={max_model_len}," ) results = lm_eval.simple_evaluate( - model="vllm", + model=backend, model_args=model_args, tasks=[task["name"] for task in eval_config["tasks"]], num_fewshot=eval_config["num_fewshot"], limit=eval_config["limit"], - batch_size="auto", + # TODO(yeq): using chat template w/ fewshot_as_multiturn is supposed help + # text models. 
however, this is regressing measured strict-match for + # existing text models in CI, so only apply it for mm. + apply_chat_template=backend == "vllm-vlm", + batch_size=batch_size, ) return results diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py index 5ea5a50a258a..c8bf7b045366 100644 --- a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py +++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py @@ -7,6 +7,7 @@ import pandas as pd +pd.options.display.float_format = "{:.2f}".format plotly_found = util.find_spec("plotly.express") is not None @@ -109,7 +110,10 @@ def compare_data_columns( if len(compare_frames) >= 2: base = compare_frames[0] current = compare_frames[-1] - ratio = current / base + if "P99" in data_column or "Median" in data_column: + ratio = base / current # for latency + else: + ratio = current / base ratio = ratio.mask(base == 0) # avoid inf when baseline is 0 ratio.name = f"Ratio 1 vs {len(compare_frames)}" frames.append(ratio) @@ -199,6 +203,71 @@ def split_json_by_tp_pp( return saved_paths +def _add_limit_line(fig, y_value, label): + # Visible dashed line + annotation + fig.add_hline( + y=y_value, + line_dash="dash", + line_color="red" if "ttft" in label.lower() else "blue", + annotation_text=f"{label}: {y_value} ms", + annotation_position="top left", + ) + # Optional: add a legend item (as a transparent helper trace) + if plot and plotly_found: + import plotly.graph_objects as go + + fig.add_trace( + go.Scatter( + x=[None], + y=[None], + mode="lines", + line=dict( + dash="dash", color="red" if "ttft" in label.lower() else "blue" + ), + name=f"{label}", + ) + ) + + +def _find_concurrency_col(df: pd.DataFrame) -> str: + for c in [ + "# of max concurrency.", + "# of max concurrency", + "Max Concurrency", + "max_concurrency", + "Concurrency", + ]: + if c in df.columns: + return c + # Fallback: guess an integer-like column (harmless if unused) + for c in df.columns: + if df[c].dtype.kind in "iu" and df[c].nunique() > 1 and df[c].min() >= 1: + return c + return "# of max concurrency." 
+ + +def _highlight_threshold( + df: pd.DataFrame, threshold: float +) -> "pd.io.formats.style.Styler": + """Highlight numeric per-configuration columns with value <= threshold.""" + conc_col = _find_concurrency_col(df) + key_cols = [ + c + for c in ["Model", "Dataset Name", "Input Len", "Output Len", conc_col] + if c in df.columns + ] + conf_cols = [ + c for c in df.columns if c not in key_cols and not str(c).startswith("Ratio") + ] + conf_cols = [c for c in conf_cols if pd.api.types.is_numeric_dtype(df[c])] + return df.style.map( + lambda v: "background-color:#e6ffe6;font-weight:bold;" + if pd.notna(v) and v <= threshold + else "", + subset=conf_cols, + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -220,6 +289,26 @@ def split_json_by_tp_pp( default="# of max concurrency.", help="column name to use as X Axis in comparison graph", ) + parser.add_argument( + "-l", + "--latency", + type=str, + default="p99", + help="take median|p99 for latency like TTFT/TPOT", + ) + parser.add_argument( + "--ttft-max-ms", + type=float, + default=3000.0, + help="Reference limit for TTFT plots (ms)", + ) + parser.add_argument( + "--tpot-max-ms", + type=float, + default=100.0, + help="Reference limit for TPOT plots (ms)", + ) + args = parser.parse_args() drop_column = "P99" @@ -234,12 +323,22 @@ def split_json_by_tp_pp( "# of max concurrency.", "qps", ] - data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"] - html_msgs_for_data_cols = [ - "Compare Output Tokens /n", - "Median TTFT /n", - "Median TPOT /n", - ] + + if "median" in args.latency: + data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"] + html_msgs_for_data_cols = [ + "Compare Output Tokens /n", + "Median TTFT /n", + "Median TPOT /n", + ] + drop_column = "P99" + elif "p99" in args.latency: + data_cols_to_compare = ["Output Tput (tok/s)", "P99 TTFT (ms)", "P99"] + html_msgs_for_data_cols = [ + "Compare Output Tokens /n", + "P99 TTFT /n", + "P99 TPOT /n", + ] if len(args.file) == 1: files = split_json_by_tp_pp(args.file[0], output_root="splits") @@ -275,33 +374,83 @@ def split_json_by_tp_pp( f"Expected subset: {filtered_info_cols}, " f"but DataFrame has: {list(output_df.columns)}" ) - output_df_sorted = output_df.sort_values(by=existing_group_cols) + # output_df_sorted = output_df.sort_values(by=existing_group_cols) + output_df_sorted = output_df.sort_values(by=args.xaxis) output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False) for name, group in output_groups: - html = group.to_html() - text_file.write(html_msgs_for_data_cols[i]) - text_file.write(html) - - if plot and plotly_found: - import plotly.express as px - - df = group[raw_data_cols] - df_sorted = df.sort_values(by=info_cols[y_axis_index]) - # Melt DataFrame for plotting - df_melted = df_sorted.melt( - id_vars=info_cols[y_axis_index], - var_name="Configuration", - value_name=data_cols_to_compare[i], + group_name = ( + ",".join(map(str, name)).replace(",", "_").replace("/", "-") + ) + group_html_name = "perf_comparison_" + group_name + ".html" + + metric_name = str(data_cols_to_compare[i]).lower() + if "tok/s" in metric_name: + html = group.to_html() + elif "ttft" in metric_name: + styler = _highlight_threshold(group, args.ttft_max_ms).format( + {c: "{:.2f}" for c in group.select_dtypes("number").columns}, + na_rep="—", + ) + html = styler.to_html( + table_attributes='border="1" class="dataframe"' + ) + elif ( + "tpot" in metric_name + or "median" in metric_name + or "p99" in 
metric_name + ): + styler = _highlight_threshold(group, args.tpot_max_ms).format( + {c: "{:.2f}" for c in group.select_dtypes("number").columns}, + na_rep="—", ) - title = data_cols_to_compare[i] + " vs " + info_cols[y_axis_index] - # Create Plotly line chart - fig = px.line( - df_melted, - x=info_cols[y_axis_index], - y=data_cols_to_compare[i], - color="Configuration", - title=title, - markers=True, + html = styler.to_html( + table_attributes='border="1" class="dataframe"' ) - # Export to HTML - text_file.write(fig.to_html(full_html=True, include_plotlyjs="cdn")) + + text_file.write(html_msgs_for_data_cols[i]) + text_file.write(html) + with open(group_html_name, "a+") as sub_text_file: + sub_text_file.write(html_msgs_for_data_cols[i]) + sub_text_file.write(html) + + if plot and plotly_found: + import plotly.express as px + + df = group[raw_data_cols] + df_sorted = df.sort_values(by=info_cols[y_axis_index]) + # Melt DataFrame for plotting + df_melted = df_sorted.melt( + id_vars=info_cols[y_axis_index], + var_name="Configuration", + value_name=data_cols_to_compare[i], + ) + title = ( + data_cols_to_compare[i] + " vs " + info_cols[y_axis_index] + ) + # Create Plotly line chart + fig = px.line( + df_melted, + x=info_cols[y_axis_index], + y=data_cols_to_compare[i], + color="Configuration", + title=title, + markers=True, + ) + + # ---- Add threshold lines based on metric name ---- + if "ttft" in metric_name: + _add_limit_line(fig, args.ttft_max_ms, "TTFT limit") + elif ( + "tpot" in metric_name + or "median" in metric_name + or "p99" in metric_name + ): + _add_limit_line(fig, args.tpot_max_ms, "TPOT limit") + + # Export to HTML + text_file.write( + fig.to_html(full_html=True, include_plotlyjs="cdn") + ) + sub_text_file.write( + fig.to_html(full_html=True, include_plotlyjs="cdn") + ) diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index a655a650cb32..a7544aeef4c7 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -63,9 +63,11 @@ "mean_ttft_ms": "Mean TTFT (ms)", "median_ttft_ms": "Median TTFT (ms)", "p99_ttft_ms": "P99 TTFT (ms)", + "std_ttft_ms": "STD TTFT (ms)", "mean_tpot_ms": "Mean TPOT (ms)", "median_tpot_ms": "Median", "p99_tpot_ms": "P99", + "std_tpot_ms": "STD TPOT (ms)", "mean_itl_ms": "Mean ITL (ms)", "median_itl_ms": "Median ITL (ms)", "p99_itl_ms": "P99 ITL (ms)", @@ -368,7 +370,7 @@ def parse_client_command(cmd: str) -> dict[str, Any]: # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", # we want to turn it into "8xGPUTYPE" df["GPU"] = df["GPU"].apply( - lambda x: f"{len(x.splitlines())}x{x.splitlines()[0]}" + lambda x: "{}x{}".format(len(x.split("\n")), x.split("\n")[0]) ) # get markdown tables diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index c64e5638029e..5a47576483bb 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -471,6 +471,11 @@ main() { mkdir -p $RESULTS_FOLDER QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ + # dump vllm info via vllm collect-env + env_output=$(vllm collect-env) + + echo "$env_output" >"$RESULTS_FOLDER/vllm_env.txt" + # benchmarking run_serving_tests 
$QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}" run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}" diff --git a/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json index 569117aae852..77d1694ec864 100644 --- a/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json +++ b/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json @@ -1,28 +1,24 @@ [ { - "test_name": "latency_llama8B_tp1", + "test_name": "latency_llama8B_tp2", "environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, "VLLM_CPU_KVCACHE_SPACE": 40 }, "parameters": { "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "load_format": "dummy", - "num_iters_warmup": 5, - "num_iters": 15 - } - }, - { - "test_name": "latency_llama8B_tp4", - "environment_variables": { - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "load_format": "dummy", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, "num_iters_warmup": 5, "num_iters": 15 } diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json index ce396d6e54f2..0b1a42e79025 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu-snc3.json @@ -95,6 +95,38 @@ "num_prompts": 200 } }, + { + "test_name": "serving_llama8B_bf16_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, { "test_name": "serving_llama8B_bf16_tp2pp3_sharegpt", "qps_list": ["inf"], @@ -233,6 +265,41 @@ "num_prompts": 1000 } }, + { + "test_name": "serving_llama8B_bf16_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + 
"max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, { "test_name": "serving_llama8B_bf16_tp2pp3_random_128_128", "qps_list": ["inf"], @@ -365,6 +432,38 @@ "num_prompts": 200 } }, + { + "test_name": "serving_llama8B_int8_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, { "test_name": "serving_llama8B_int8_tp2pp3_sharegpt", "qps_list": ["inf"], @@ -503,6 +602,41 @@ "num_prompts": 1000 } }, + { + "test_name": "serving_llama8B_int8_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, { "test_name": "serving_llama8B_int8_tp2pp3_random_128_128", "qps_list": ["inf"], @@ -638,6 +772,39 @@ "num_prompts": 200 } }, + { + "test_name": "serving_llama8B_int4_tp4_sharegpt", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": 
"./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200 + } + }, { "test_name": "serving_llama8B_int4_tp2pp3_sharegpt", "qps_list": ["inf"], @@ -780,6 +947,42 @@ "num_prompts": 1000 } }, + { + "test_name": "serving_llama8B_int4_tp4_random_128_128", + "qps_list": ["inf"], + "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "quantization": "awq", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 1000 + } + }, { "test_name": "serving_llama8B_int4_tp2pp3_random_128_128", "qps_list": ["inf"], diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json index e21c8df0a9fe..f792956f3947 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json @@ -2,7 +2,7 @@ { "test_name": "serving_llama8B_tp1_sharegpt", "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "max_concurrency_list": [32], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -28,13 +28,13 @@ "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 + "num_prompts": 32 } }, { "test_name": "serving_llama8B_tp2_sharegpt", "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "max_concurrency_list": [32], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -60,13 +60,13 @@ "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 + "num_prompts": 32 } }, { - "test_name": "serving_llama8B_tp4_sharegpt", + "test_name": "serving_llama8B_tp1_random_128_128", "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "max_concurrency_list": [32], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -76,11 +76,12 @@ }, "server_parameters": { "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, + "tensor_parallel_size": 1, "dtype": "bfloat16", "distributed_executor_backend": "mp", "block_size": 128, "trust_remote_code": "", + "enable_chunked_prefill": "", "disable_log_stats": "", "enforce_eager": "", "max_num_batched_tokens": 2048, @@ -90,15 +91,122 @@ "client_parameters": { "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200 + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 32 + } + }, + { 
+ "test_name": "serving_llama8B_tp2_random_128_128", + "qps_list": [1, 4, 16, "inf"], + "max_concurrency_list": [32], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 128, + "ignore-eos": "", + "num_prompts": 32 + } + }, + { + "test_name": "serving_llama8B_tp1_random_128_2048", + "qps_list": [1, 4, 16, "inf"], + "max_concurrency_list": [32], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048, + "ignore-eos": "", + "num_prompts": 32 + } + }, + { + "test_name": "serving_llama8B_tp2_random_128_2048", + "qps_list": [1, 4, 16, "inf"], + "max_concurrency_list": [32], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 128, + "random-output-len": 2048, + "ignore-eos": "", + "num_prompts": 32 } }, { - "test_name": "serving_llama8B_tp4_random_1024_128", + "test_name": "serving_llama8B_tp1_random_2048_128", "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "max_concurrency_list": [32], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -108,7 +216,7 @@ }, "server_parameters": { "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, + "tensor_parallel_size": 1, "dtype": "bfloat16", "distributed_executor_backend": "mp", "block_size": 128, @@ -124,16 +232,16 @@ "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "random", - "random-input-len": 1024, + "random-input-len": 2048, "random-output-len": 
128, "ignore-eos": "", - "num_prompts": 100 + "num_prompts": 32 } }, { - "test_name": "serving_llama8B_pp6_random_1024_128", + "test_name": "serving_llama8B_tp2_random_2048_128", "qps_list": [1, 4, 16, "inf"], - "max_concurrency_list": [12, 16, 24, 32, 64, 128, 200], + "max_concurrency_list": [32], "server_environment_variables": { "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, @@ -143,7 +251,7 @@ }, "server_parameters": { "model": "meta-llama/Llama-3.1-8B-Instruct", - "pipeline_parallel_size": 6, + "tensor_parallel_size": 2, "dtype": "bfloat16", "distributed_executor_backend": "mp", "block_size": 128, @@ -159,10 +267,10 @@ "model": "meta-llama/Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "random", - "random-input-len": 1024, + "random-input-len": 2048, "random-output-len": 128, "ignore-eos": "", - "num_prompts": 100 + "num_prompts": 32 } } ] diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json index 48c015aa8403..dc214ddfb27e 100644 --- a/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json +++ b/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json @@ -1,29 +1,24 @@ [ { - "test_name": "throughput_llama8B_tp1", + "test_name": "throughput_llama8B_tp2", "environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_SGL_KERNEL": 1, "VLLM_CPU_KVCACHE_SPACE": 40 }, "parameters": { "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 1, - "load_format": "dummy", - "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", - "num_prompts": 200, - "backend": "vllm" - } - }, - { - "test_name": "throughput_llama8B_tp4", - "environment_variables": { - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, - "VLLM_CPU_KVCACHE_SPACE": 40 - }, - "parameters": { - "model": "meta-llama/Llama-3.1-8B-Instruct", - "tensor_parallel_size": 4, - "load_format": "dummy", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "enforce_eager": "", + "max_num_batched_tokens": 2048, + "max_num_seqs": 256, "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", "num_prompts": 200, "backend": "vllm" diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 9cee502015c7..33b7114666fa 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,5 +1,5 @@ steps: - # aarch64 + CUDA builds. PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 + # aarch64 + CUDA builds - label: "Build arm64 wheel - CUDA 12.9" depends_on: ~ id: build-wheel-arm64-cuda-12-9 @@ -8,13 +8,28 @@ steps: commands: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." 
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" + # aarch64 build + - label: "Build arm64 CPU wheel" + depends_on: ~ + id: build-wheel-arm64-cpu + agents: + queue: arm64_cpu_queue_postmerge + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_BUILD_ACL=ON --tag vllm-ci:build-image --target vllm-build --progress plain -f docker/Dockerfile.cpu ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + - "bash .buildkite/scripts/upload-wheels.sh" + env: + DOCKER_BUILDKIT: "1" + + # x86 + CUDA builds - label: "Build wheel - CUDA 12.8" depends_on: ~ id: build-wheel-cuda-12-8 @@ -28,33 +43,33 @@ steps: env: DOCKER_BUILDKIT: "1" - - label: "Build wheel - CUDA 12.6" + - label: "Build wheel - CUDA 12.9" depends_on: ~ - id: build-wheel-cuda-12-6 + id: build-wheel-cuda-12-9 agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" - # x86 + CUDA builds - - label: "Build wheel - CUDA 12.9" + - label: "Build wheel - CUDA 13.0" depends_on: ~ - id: build-wheel-cuda-12-9 + id: build-wheel-cuda-13-0 agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=13.0.1 --build-arg BUILD_BASE_IMAGE=nvidia/cuda:13.0.1-devel-ubuntu22.04 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." 
- "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" + # Build release images (12.9) - label: "Build release image (x86)" depends_on: ~ id: build-release-image-x86 @@ -62,13 +77,12 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" # re-tag to default image tag and push, just in case arm64 build fails - "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - # PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9 - label: "Build release image (arm64)" depends_on: ~ id: build-release-image-arm64 @@ -76,7 +90,7 @@ steps: queue: arm64_cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." 
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" # Add job to create multi-arch manifest @@ -142,6 +156,22 @@ steps: env: DOCKER_BUILDKIT: "1" + - block: "Build arm64 CPU release image" + key: block-arm64-cpu-release-image-build + depends_on: ~ + + - label: "Build and publish arm64 CPU release image" + depends_on: block-arm64-cpu-release-image-build + agents: + queue: arm64_cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:latest" + - "docker push public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:$(buildkite-agent meta-data get release-version)" + env: + DOCKER_BUILDKIT: "1" + - label: "Build and publish nightly multi-arch image to DockerHub" depends_on: - create-multi-arch-manifest diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh index 36bcb015d308..39ea18017308 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh @@ -25,25 +25,28 @@ function cpu_tests() { # offline inference podman exec -it "$container_id" bash -c " - set -e - python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" + set -xve + python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" >> $HOME/test_basic.log # Run basic model test podman exec -it "$container_id" bash -c " - set -e + set -evx pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib pip install sentence-transformers datamodel_code_generator - pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model + + # Note: disable Bart until supports V1 + # pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2] pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m] pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it] pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach] - pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model" + # TODO: Below test case tests/models/language/pooling/test_embedding.py::test_models[True-ssmits/Qwen2-7B-Instruct-embed-base] fails on ppc64le. Disabling it for time being. + # pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model" >> $HOME/test_rest.log } # All of CPU tests are expected to be finished less than 40 mins. 
export container_id export -f cpu_tests -timeout 40m bash -c cpu_tests +timeout 120m bash -c cpu_tests diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 7512cb1bbed0..7927aef19e4e 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -70,7 +70,7 @@ function cpu_tests() { docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -x -s -v \ - tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]" + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs" # Note: disable it until supports V1 # Run AWQ test diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 2fd7265fa536..250a64fdd071 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -44,6 +44,5 @@ docker run \ pytest -v -s v1/structured_output pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py - pytest -v -s v1/test_metrics pytest -v -s v1/test_serial_utils.py ' diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 43aa8c47be29..945c5e48c009 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -58,33 +58,25 @@ python3 .buildkite/generate_index.py --wheel "$normal_wheel" aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" -if [[ $normal_wheel == *"cu126"* ]]; then - # if $normal_wheel matches cu126, do not upload the index.html - echo "Skipping index files for cu126 wheels" -elif [[ $normal_wheel == *"cu128"* ]]; then - # if $normal_wheel matches cu128, do not upload the index.html - echo "Skipping index files for cu128 wheels" -else +if [[ $normal_wheel == *"cu129"* ]]; then # only upload index.html for cu129 wheels (default wheels) as it # is available on both x86 and arm64 aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" +else + echo "Skipping index files for non-cu129 wheels" fi # generate index for nightly aws s3 cp "$wheel" "s3://vllm-wheels/nightly/" aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" -if [[ $normal_wheel == *"cu126"* ]]; then - # if $normal_wheel matches cu126, do not upload the index.html - echo "Skipping index files for cu126 wheels" -elif [[ $normal_wheel == *"cu128"* ]]; then - # if $normal_wheel matches cu128, do not upload the index.html - echo "Skipping index files for cu128 wheels" -else +if [[ $normal_wheel == *"cu129"* ]]; then # only upload index.html for cu129 wheels (default wheels) as it # is available on both x86 and arm64 aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" +else + echo "Skipping index files for non-cu129 wheels" fi aws s3 cp "$wheel" "s3://vllm-wheels/$version/" diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml new file mode 100644 index 000000000000..92e27f143d8d --- /dev/null +++ b/.buildkite/test-amd.yaml @@ -0,0 +1,1326 @@ +# In this file, you can add more tests to 
run either by adding a new step or +# adding a new command to an existing step. See different options here for examples. + +# This script will be feed into Jinja template in `test-template-aws.j2` at +# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 +# to generate the final pipeline yaml file. + +# Documentation +# label(str): the name of the test. emojis allowed. +# fast_check(bool): whether to run this on each commit on the fastcheck pipeline. +# torch_nightly(bool): whether to run this on vllm against the torch nightly pipeline. +# fast_check_only(bool): run this test on the fastcheck pipeline only +# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's a scheduled nightly run. +# soft_fail(bool): allow this step to fail without failing the entire pipeline (useful for flaky or experimental tests). +# command(str): the single command to run for tests. incompatible with commands. +# commands(list): the list of commands to run for the test. incompatible with command. +# mirror_hardwares(list): the list of hardware to run the test on as well. currently only supports [amdexperimental] +# gpu(str): override the GPU selection for the test. default is L4 GPUs. supports a100, b200, h200 +# num_gpus(int): override the number of GPUs for the test. defaults to 1 GPU. currently supports 2,4. +# num_nodes(int): whether to simulate multi-node setup by launching multiple containers on one host, +# in this case, commands must be specified. the first command runs on the first host, the second +# command runs on the second host. +# timeout_in_minutes(int): sets a timeout for the step in minutes. if not specified, uses the default timeout. +# parallelism(int): number of parallel jobs to run for this step. enables test sharding using $$BUILDKITE_PARALLEL_JOB +# and $$BUILDKITE_PARALLEL_JOB_COUNT environment variables. +# working_dir(str): specify the place where the command should execute, default to /vllm-workspace/tests +# source_file_dependencies(list): the list of prefixes to opt-in the test for, if empty, the test will always run. + +# When adding a test +# - If the test belongs to an existing group, add it there +# - If the test is short, add to any existing step +# - If the test takes more than 10min, then it is okay to create a new step. +# Note that all steps execute in parallel. + +steps: +##### fast check tests ##### + +- label: Pytorch Nightly Dependency Override Check # 2min + # if this test fails, it means the nightly torch version is not compatible with some + # of the dependencies. 
Please check the error message and add the package to whitelist + # in /vllm/tools/generate_nightly_torch_test.py + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + soft_fail: true + source_file_dependencies: + - requirements/nightly_torch_test.txt + commands: + - bash standalone_tests/pytorch_nightly_dependency.sh + +- label: Async Engine, Inputs, Utils, Worker Test # 36min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/multimodal + - tests/utils_ + commands: + - pytest -v -s -m 'not cpu_test' multimodal + - pytest -v -s utils_ + +- label: Async Engine, Inputs, Utils, Worker Test (CPU) # 4 mins + timeout_in_minutes: 10 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/test_inputs.py + - tests/test_outputs.py + - tests/multimodal + - tests/standalone_tests/lazy_imports.py + - tests/transformers_utils + no_gpu: true + commands: + - python3 standalone_tests/lazy_imports.py + - pytest -v -s test_inputs.py + - pytest -v -s test_outputs.py + - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s transformers_utils + +- label: Python-only Installation Test # 10min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - tests/standalone_tests/python_only_compile.sh + - setup.py + commands: + - bash standalone_tests/python_only_compile.sh + +- label: Basic Correctness Test # 20min + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + fast_check: true + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/basic_correctness/test_basic_correctness + - tests/basic_correctness/test_cpu_offload + - tests/basic_correctness/test_cumem.py + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s basic_correctness/test_cumem.py + - pytest -v -s basic_correctness/test_basic_correctness.py + - pytest -v -s basic_correctness/test_cpu_offload.py + +- label: Entrypoints Unit Tests # 5min + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + timeout_in_minutes: 10 + working_dir: "/vllm-workspace/tests" + fast_check: true + source_file_dependencies: + - vllm/entrypoints + - tests/entrypoints/ + commands: + - pytest -v -s entrypoints/openai/tool_parsers + - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling + +- label: Entrypoints Integration Test (LLM) # 30min + timeout_in_minutes: 40 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + fast_check: true + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/entrypoints/llm + - tests/entrypoints/offline_mode + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py + - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process + - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests + +- label: Entrypoints Integration Test (API Server) # 100min + timeout_in_minutes: 130 + mirror_hardwares: [amdexperimental] + 
agent_pool: mi325_1 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + fast_check: true + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/entrypoints/openai + - tests/entrypoints/test_chat_utils + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/ + - pytest -v -s entrypoints/test_chat_utils.py + +- label: Entrypoints Integration Test (Pooling) + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + fast_check: true + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/entrypoints/pooling + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s entrypoints/pooling + +- label: Distributed Tests (4 GPUs) # 35min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - tests/distributed/test_utils + - tests/distributed/test_pynccl + - tests/distributed/test_events + - tests/compile/test_basic_correctness + - examples/offline_inference/rlhf.py + - examples/offline_inference/rlhf_colocate.py + - tests/examples/offline_inference/data_parallel.py + - tests/v1/distributed + - tests/v1/engine/test_engine_core_client.py + - tests/distributed/test_symm_mem_allreduce.py + commands: + # test with torchrun tp=2 and external_dp=2 + - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=2 and pp=2 + - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with torchrun tp=4 and dp=1 + - TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2, pp=2 and dp=1 + - PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=1 and dp=4 with ep + - DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with torchrun tp=2 and dp=2 with ep + - TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py + # test with internal dp + - python3 ../examples/offline_inference/data_parallel.py --enforce-eager + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py + - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp + - pytest -v -s distributed/test_utils.py + - pytest -v -s compile/test_basic_correctness.py + - pytest -v -s distributed/test_pynccl.py + - pytest -v -s distributed/test_events.py + - pytest -v -s distributed/test_symm_mem_allreduce.py + # TODO: create a dedicated test section for multi-GPU example tests + # when we have multiple distributed example tests + - pushd ../examples/offline_inference + - 
VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - popd + +- label: EPLB Algorithm Test # 5min + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + timeout_in_minutes: 15 + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_algo.py + commands: + - pytest -v -s distributed/test_eplb_algo.py + +- label: EPLB Execution Test # 5min + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_4 + # grade: Blocking + timeout_in_minutes: 15 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_execute.py + commands: + - pytest -v -s distributed/test_eplb_execute.py + +- label: Metrics, Tracing Test # 12min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_2 + # grade: Blocking + num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/v1/tracing + commands: + - "pip install \ + 'opentelemetry-sdk>=1.26.0' \ + 'opentelemetry-api>=1.26.0' \ + 'opentelemetry-exporter-otlp>=1.26.0' \ + 'opentelemetry-semantic-conventions-ai>=0.4.1'" + - pytest -v -s v1/tracing + +##### fast check tests ##### +##### 1 GPU test ##### + +- label: Regression Test # 7min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + grade: Blocking + source_file_dependencies: + - vllm/ + - tests/test_regression + commands: + - pip install modelscope + - pytest -v -s test_regression.py + working_dir: "/vllm-workspace/tests" # optional + +- label: Engine Test # 25min + timeout_in_minutes: 40 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + #grade: Blocking + source_file_dependencies: + - vllm/ + - tests/engine + - tests/tokenization + - tests/test_sequence + - tests/test_config + - tests/test_logger + - tests/test_vllm_port + commands: + - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py + # OOM in the CI unless we run this separately + - pytest -v -s tokenization + +- label: V1 Test e2e + engine # 30min + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # TODO: accuracy does not match, whether setting + # VLLM_USE_FLASHINFER_SAMPLER or not on H100. + - pytest -v -s v1/e2e + - pytest -v -s v1/engine + +- label: V1 Test entrypoints # 35min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + - pytest -v -s v1/entrypoints + +- label: V1 Test others # 42min + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # split the test to avoid interference + - pytest -v -s -m 'not cpu_test' v1/core + - pytest -v -s v1/executor + - pytest -v -s v1/kv_offload + - pytest -v -s v1/sample + - pytest -v -s v1/logits_processors + - pytest -v -s v1/worker + - pytest -v -s v1/spec_decode + - pytest -v -s -m 'not cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'not cpu_test' v1/metrics + - pytest -v -s v1/test_oracle.py + - pytest -v -s v1/test_request.py + # Integration test for streaming correctness (requires special branch). 
+ - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api + - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine + +- label: V1 Test others (CPU) # 5 mins + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/v1 + no_gpu: true + commands: + # split the test to avoid interference + - pytest -v -s -m 'cpu_test' v1/core + - pytest -v -s v1/structured_output + - pytest -v -s v1/test_serial_utils.py + - pytest -v -s -m 'cpu_test' v1/kv_connector/unit + - pytest -v -s -m 'cpu_test' v1/metrics + + +- label: Examples Test # 30min + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + working_dir: "/vllm-workspace/examples" + source_file_dependencies: + - vllm/entrypoints + - examples/ + commands: + - pip install tensorizer # for tensorizer test + - python3 offline_inference/basic/generate.py --model facebook/opt-125m + - python3 offline_inference/basic/generate.py --model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10 + - python3 offline_inference/basic/chat.py + - python3 offline_inference/prefix_caching.py + - python3 offline_inference/llm_engine_example.py + - python3 offline_inference/audio_language.py --seed 0 + - python3 offline_inference/vision_language.py --seed 0 + - python3 offline_inference/vision_language_pooling.py --seed 0 + - python3 offline_inference/vision_language_multi_image.py --seed 0 + - python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 + - python3 offline_inference/basic/classify.py + - python3 offline_inference/basic/embed.py + - python3 offline_inference/basic/score.py + - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU + - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 + #- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + +- label: Platform Tests (CUDA) # 4min + timeout_in_minutes: 15 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/cuda + commands: + - pytest -v -s cuda/test_cuda_context.py + +- label: Samplers Test # 56min + timeout_in_minutes: 75 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/model_executor/layers + - vllm/sampling_metadata.py + - tests/samplers + - tests/conftest.py + commands: + - pytest -v -s samplers + - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers + +- label: LoRA Test 
%N # 20min each + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_8 + # grade: Blocking + source_file_dependencies: + - vllm/lora + - tests/lora + commands: + - pytest -v -s lora \ + --shard-id=$$BUILDKITE_PARALLEL_JOB \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --ignore=lora/test_chatglm3_tp.py \ + --ignore=lora/test_llama_tp.py \ + --ignore=lora/test_llm_with_multi_loras.py \ + --ignore=lora/test_olmoe_tp.py \ + --ignore=lora/test_deepseekv2_tp.py \ + --ignore=lora/test_gptoss.py \ + --ignore=lora/test_qwen3moe_tp.py + parallelism: 4 + +- label: PyTorch Compilation Unit Tests # 15min + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_pass_manager.py + - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_fusion_attn.py + - pytest -v -s compile/test_functionalization.py + - pytest -v -s compile/test_silu_mul_quant_fusion.py + # - pytest -v -s compile/test_sequence_parallelism.py + # - pytest -v -s compile/test_async_tp.py + - pytest -v -s compile/test_fusion_all_reduce.py + - pytest -v -s compile/test_decorator.py + - pytest -v -s compile/test_noop_elimination.py + - pytest -v -s compile/test_aot_compile.py + +- label: PyTorch Fullgraph Smoke Test # 15min + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_basic_correctness.py + - pytest -v -s compile/piecewise/ + +- label: PyTorch Fullgraph Test # 22min + timeout_in_minutes: 35 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph.py + - pytest -v -s compile/test_fusions_e2e.py + +- label: Kernels Core Operation Test # 48min + timeout_in_minutes: 75 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - csrc/ + - tests/kernels/core + - tests/kernels/test_top_k_per_row.py + commands: + - pytest -v -s kernels/core kernels/test_top_k_per_row.py + +- label: Kernels Attention Test %N # 23min + timeout_in_minutes: 35 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_8 + # grade: Blocking + source_file_dependencies: + - csrc/attention/ + - vllm/attention + - vllm/v1/attention + - tests/kernels/attention + commands: + - pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels Quantization Test %N # 64min + timeout_in_minutes: 90 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_8 + # grade: Blocking + source_file_dependencies: + - csrc/quantization/ + - vllm/model_executor/layers/quantization + - tests/kernels/quantization + commands: + - pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels MoE Test %N # 40min + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_8 + # grade: Blocking + source_file_dependencies: + - csrc/quantization/cutlass_w8a8/moe/ + - csrc/moe/ + - tests/kernels/moe + - vllm/model_executor/layers/fused_moe/ + - vllm/distributed/device_communicators/ + commands: 
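The `%N` kernel steps above fan out via `parallelism`, and each parallel job receives `BUILDKITE_PARALLEL_JOB` (0-indexed) and `BUILDKITE_PARALLEL_JOB_COUNT` from the agent, which the commands pass through as `--shard-id` and `--num-shards`. A hedged sketch of the shard arithmetic is below; it shows a plausible round-robin split under those environment variables, not necessarily the exact algorithm of the pytest sharding plugin used in CI.

# Sketch of shard selection in the spirit of
# --shard-id=$BUILDKITE_PARALLEL_JOB --num-shards=$BUILDKITE_PARALLEL_JOB_COUNT.
import os

def select_shard(items: list[str], shard_id: int, num_shards: int) -> list[str]:
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    # Sort for a deterministic order, then round-robin so shards stay balanced.
    return [t for i, t in enumerate(sorted(items)) if i % num_shards == shard_id]

if __name__ == "__main__":
    tests = [f"kernels/moe/test_case_{i}.py" for i in range(7)]  # placeholder names
    shard = int(os.environ.get("BUILDKITE_PARALLEL_JOB", 0))
    count = int(os.environ.get("BUILDKITE_PARALLEL_JOB_COUNT", 2))
    print(select_shard(tests, shard, count))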
+ - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 2 + +- label: Kernels Mamba Test # 31min + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - csrc/mamba/ + - tests/kernels/mamba + - vllm/model_executor/layers/mamba/ops + commands: + - pytest -v -s kernels/mamba + +- label: Model Executor Test # 23min + timeout_in_minutes: 35 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/model_executor + - tests/model_executor + - tests/entrypoints/openai/test_tensorizer_entrypoint.py + commands: + - apt-get update && apt-get install -y curl libsodium23 + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s model_executor + - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py + +- label: Benchmarks # 11min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_8 + # grade: Blocking + working_dir: "/vllm-workspace/.buildkite" + source_file_dependencies: + - benchmarks/ + commands: + - bash scripts/run-benchmarks.sh + +- label: Benchmarks CLI Test # 7min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_8 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/benchmarks/ + commands: + - pytest -v -s benchmarks/ + +- label: Quantization Test # 70min + timeout_in_minutes: 90 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + - tests/quantization + commands: + # temporary install here since we need nightly, will move to requirements/test.in + # after torchao 0.12 release, and pin a working version of torchao nightly here + + # since torchao nightly is only compatible with torch nightly currently + # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now + # we can only upgrade after this is resolved + # TODO(jerryzh168): resolve the above comment + - uv pip install --system torchao==0.13.0 + - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py + +- label: LM Eval Small Models # 53min + timeout_in_minutes: 75 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 + +- label: OpenAI API correctness # 22min + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - csrc/ + - vllm/entrypoints/openai/ + - vllm/model_executor/models/whisper.py + commands: # LMEval+Transcription WER check + - pytest -s entrypoints/openai/correctness/ + +- label: OpenAI-Compatible Tool Use # 23 min + timeout_in_minutes: 35 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + fast_check: false + source_file_dependencies: + - vllm/ + - tests/tool_use + commands: + - pytest -v -s -m 'not cpu_test' tool_use + +- label: OpenAI-Compatible Tool Use (CPU) # 5 mins + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + timeout_in_minutes: 10 + source_file_dependencies: + - vllm/ + - tests/tool_use + 
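The LM Eval steps above gate merges on accuracy staying near a recorded baseline for each model in a `--config-list-file` list. The sketch below shows the general shape of such a threshold check under an assumed YAML schema and an assumed 5% relative tolerance; the real config layout and tolerance used by the correctness tests may differ. Requires PyYAML.

# Rough sketch of a metric-threshold gate; schema and tolerance are assumptions.
import yaml

RTOL = 0.05  # assumed relative tolerance, not necessarily what CI uses

EXAMPLE_CFG = """
model_name: example/model            # hypothetical entry for illustration
tasks:
- name: gsm8k
  metrics:
  - name: exact_match,strict-match
    value: 0.41
"""

def check_accuracy(cfg: dict, measured: dict[str, float]) -> None:
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            expected, got = metric["value"], measured[metric["name"]]
            assert got >= expected * (1 - RTOL), (
                f"{cfg['model_name']} {task['name']} {metric['name']}: "
                f"{got:.3f} < {expected:.3f}")

if __name__ == "__main__":
    check_accuracy(yaml.safe_load(EXAMPLE_CFG),
                   {"exact_match,strict-match": 0.43})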
no_gpu: true + commands: + - pytest -v -s -m 'cpu_test' tool_use + +##### models test ##### + +- label: Basic Models Tests (Initialization) + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_initialization.py + commands: + # Run a subset of model initialization tests + - pytest -v -s models/test_initialization.py::test_can_initialize_small_subset + +- label: Basic Models Tests (Extra Initialization) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_8 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/test_initialization.py + commands: + # Only when vLLM model source is modified - test initialization of a large + # subset of supported models (the complement of the small subset in the above + # test.) Also run if model initialization test file is modified + - pytest -v -s models/test_initialization.py \ + -k 'not test_can_initialize_small_subset' \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Basic Models Tests (Other) + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_transformers.py + - tests/models/test_registry.py + commands: + - pytest -v -s models/test_transformers.py models/test_registry.py + +- label: Basic Models Test (Other CPU) # 5min + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + timeout_in_minutes: 10 + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/test_utils.py + - tests/models/test_vision.py + no_gpu: true + commands: + - pytest -v -s models/test_utils.py models/test_vision.py + +- label: Language Models Tests (Standard) + timeout_in_minutes: 25 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/language + commands: + # Test standard language models, excluding a subset of slow tests + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and (not slow_test)' + +- label: Language Models Tests (Extra Standard) %N + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_8 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/model_executor/models/ + - tests/models/language/pooling/test_embedding.py + - tests/models/language/generation/test_common.py + - tests/models/language/pooling/test_classification.py + commands: + # Shard slow subset of standard language models tests. 
Only run when model + # source is modified, or when specified test files are modified + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m 'core_model and slow_test' \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Tests (Hybrid) %N + timeout_in_minutes: 75 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_8 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation + commands: + # Install fast path packages for testing against transformers + # Note: also needed to run plamo2 model in vLLM + - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5' + - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2' + # Shard hybrid language model tests + - pytest -v -s models/language/generation \ + -m hybrid_model \ + --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ + --shard-id=$$BUILDKITE_PARALLEL_JOB + parallelism: 2 + +- label: Language Models Test (Extended Generation) # 80min + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation + commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' + +- label: Language Models Test (PPL) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation_ppl_test + commands: + - pytest -v -s models/language/generation_ppl_test + +- label: Language Models Test (Extended Pooling) # 36min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling + commands: + - pytest -v -s models/language/pooling -m 'not core_model' + +- label: Language Models Test (MTEB) + timeout_in_minutes: 110 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling_mteb_test + commands: + - pytest -v -s models/language/pooling_mteb_test + +- label: Multi-Modal Processor Test # 44min + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/processing + +- label: Multi-Modal Models Test (Standard) # 60min + timeout_in_minutes: 80 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pip freeze | grep -E 'torch' + - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing + - cd .. 
&& VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work + +- label: Multi-Modal Accuracy Eval (Small Models) # 50min + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + timeout_in_minutes: 70 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - vllm/multimodal/ + - vllm/inputs/ + - vllm/v1/core/ + commands: + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1 + +- label: Multi-Modal Models Test (Extended) 1 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing + +- label: Multi-Modal Models Test (Extended) 2 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' + +- label: Multi-Modal Models Test (Extended) 3 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' + +- label: Quantized Models Test # 45 min + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + # grade: Blocking + source_file_dependencies: + - vllm/model_executor/layers/quantization + - tests/models/quantization + commands: + - pytest -v -s models/quantization + +# This test is used only in PR development phase to test individual models and should never run on main +- label: Custom Models Test + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + # grade: Blocking + optional: true + commands: + - echo 'Testing custom models...' + # PR authors can temporarily add commands below to test individual models + # e.g. 
pytest -v -s models/encoder_decoder/vision_language/test_mllama.py + # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* + +- label: Transformers Nightly Models Test + mirror_hardwares: [amdexperimental] + agent_pool: mi325_1 + working_dir: "/vllm-workspace/" + optional: true + commands: + - pip install --upgrade git+https://github.com/huggingface/transformers + - pytest -v -s tests/models/test_initialization.py + - pytest -v -s tests/models/test_transformers.py + - pytest -v -s tests/models/multimodal/processing/ + - pytest -v -s tests/models/multimodal/test_mapping.py + - python3 examples/offline_inference/basic/chat.py + - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl + # Whisper needs spawn method to avoid deadlock + - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper + +- label: Blackwell Test # 21 min + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/" + gpu: b200 + # optional: true + source_file_dependencies: + - csrc/quantization/fp4/ + - csrc/attention/mla/ + - csrc/quantization/cutlass_w8a8/moe/ + - vllm/model_executor/layers/fused_moe/cutlass_moe.py + - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py + - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - nvidia-smi + - python3 examples/offline_inference/basic/chat.py + # Attention + # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353 + - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2' + - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py + - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py + - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py + # Quantization + - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' + - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py + - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py + - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py + - pytest -v -s tests/kernels/moe/test_flashinfer.py + +- label: Blackwell Fusion Tests # 30 min + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - csrc/quantization/fp4/ + - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + commands: + - nvidia-smi + - pytest -v -s tests/compile/test_fusion_attn.py + - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + # this runner has 2 GPUs available even though num_gpus=2 is not set + - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py + +- label: Blackwell GPT-OSS Eval + 
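Several steps here depend on the runner's GPU count (for example the note above that the Blackwell fusion runner exposes 2 GPUs, and the `distributed(num_gpus=2)` selections further down). A small hedged sketch of guarding a multi-GPU test is below; it mirrors the intent of those markers but is not vLLM's actual marker implementation.

# Skip a multi-GPU test when the runner has too few devices; illustrative only.
import pytest
import torch

requires_two_gpus = pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="needs at least 2 GPUs",
)

@requires_two_gpus
def test_pairwise_collective_placeholder():
    # Placeholder assertion; a real test would exercise a 2-GPU collective.
    assert torch.cuda.device_count() >= 2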
timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + optional: true # run on nightlies + source_file_dependencies: + - tests/evals/gpt_oss + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - uv pip install --system 'gpt-oss[eval]==0.0.5' + - pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 + +- label: Blackwell Quantized MoE Test + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - tests/quantization/test_blackwell_moe.py + - vllm/model_executor/models/deepseek_v2.py + - vllm/model_executor/models/gpt_oss.py + - vllm/model_executor/models/llama4.py + - vllm/model_executor/layers/fused_moe + - vllm/model_executor/layers/quantization/compressed_tensors + - vllm/model_executor/layers/quantization/modelopt.py + - vllm/model_executor/layers/quantization/mxfp4.py + - vllm/v1/attention/backends/flashinfer.py + commands: + - pytest -s -v tests/quantization/test_blackwell_moe.py + +- label: Blackwell LM Eval Small Models + timeout_in_minutes: 120 + gpu: b200 + optional: true # run on nightlies + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 + +##### 1 GPU test ##### +##### multi gpus test ##### + +- label: Distributed Comm Ops Test # 7min + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_2 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/distributed + - tests/distributed + commands: + - pytest -v -s distributed/test_comm_ops.py + - pytest -v -s distributed/test_shm_broadcast.py + - pytest -v -s distributed/test_shm_buffer.py + - pytest -v -s distributed/test_shm_storage.py + +- label: 2 Node Tests (4 GPUs in total) # 16min + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + num_nodes: 2 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + - tests/examples/offline_inference/data_parallel.py + commands: + - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' + - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' + - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code + - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py + - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py + - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' + - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 
--rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' + - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code + +- label: Distributed Tests (2 GPUs) # 68min + timeout_in_minutes: 90 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_2 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/compilation/ + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/worker/worker_base.py + - vllm/v1/engine/ + - vllm/v1/worker/ + - tests/compile/test_basic_correctness.py + - tests/compile/test_wrapper.py + - tests/distributed/ + - tests/entrypoints/llm/test_collective_rpc.py + - tests/v1/distributed + - tests/v1/entrypoints/openai/test_multi_api_servers.py + - tests/v1/shutdown + - tests/v1/worker/test_worker_memory_snapshot.py + commands: + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py + - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py + - pytest -v -s entrypoints/llm/test_collective_rpc.py + - pytest -v -s ./compile/test_basic_correctness.py + - pytest -v -s ./compile/test_wrapper.py + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - pytest -v -s distributed/test_sequence_parallel.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown + - pytest -v -s v1/worker/test_worker_memory_snapshot.py + +- label: Distributed Model Tests (2 GPUs) # 37min + timeout_in_minutes: 50 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_2 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/model_executor/model_loader/sharded_state_loader.py + - vllm/model_executor/models/ + - tests/basic_correctness/ + - tests/model_executor/model_loader/test_sharded_state_loader.py + - tests/models/ + commands: + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py + # Avoid importing model tests that cause CUDA reinitialization error + - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/language -v -s -m 'distributed(num_gpus=2)' + - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py + - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)' + +- label: Plugin Tests (2 GPUs) # 40min + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_2 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/plugins/ + - tests/plugins/ + commands: + # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform + - pip install -e ./plugins/vllm_add_dummy_platform + - pytest -v -s plugins_tests/test_platform_plugins.py + - pip uninstall vllm_add_dummy_platform -y + # end platform plugin tests + # begin io_processor plugins test, all the code in between uses the 
prithvi_io_processor plugin + - pip install -e ./plugins/prithvi_io_processor_plugin + - pytest -v -s plugins_tests/test_io_processor_plugins.py + - pip uninstall prithvi_io_processor_plugin -y + # end io_processor plugins test + # begin stat_logger plugins test + - pip install -e ./plugins/vllm_add_dummy_stat_logger + - pytest -v -s plugins_tests/test_stats_logger_plugins.py + - pip uninstall dummy_stat_logger -y + # end stat_logger plugins test + # other tests continue here: + - pytest -v -s plugins_tests/test_scheduler_plugins.py + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s distributed/test_distributed_oot.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins + +- label: Pipeline + Context Parallelism Test # 45min + timeout_in_minutes: 60 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_4 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + commands: + - pytest -v -s distributed/test_pp_cudagraph.py + - pytest -v -s distributed/test_pipeline_parallel.py + +- label: LoRA TP Test (Distributed) # 17 min + timeout_in_minutes: 30 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_4 + # grade: Blocking + num_gpus: 4 + source_file_dependencies: + - vllm/lora + - tests/lora + commands: + # FIXIT: find out which code initialize cuda before running the test + # before the fix, we need to use spawn to test it + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # There is some Tensor Parallelism related processing logic in LoRA that + # requires multi-GPU testing for validation. 
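On the `VLLM_WORKER_MULTIPROC_METHOD=spawn` exports used above: a forked child inherits the parent's already-initialized CUDA context, which CUDA does not support, while `spawn` starts a clean interpreter so the GPU runtime is initialized fresh in each worker. A minimal, framework-free sketch of that start-method choice follows; it uses a stand-in worker function and only borrows the environment-variable name from the pipeline, it is not vLLM's worker code.

# Why spawn rather than fork once a GPU runtime is live in the parent (sketch).
import multiprocessing as mp
import os

def worker(rank: int) -> None:
    # In a spawned child this runs in a fresh interpreter, so any GPU runtime
    # would be initialized here for the first time instead of being inherited.
    print(f"worker {rank} running in pid {os.getpid()}")

def main() -> None:
    # Pick the start method explicitly instead of relying on the platform
    # default (fork on Linux); reading the vLLM variable here is illustrative.
    method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
    ctx = mp.get_context(method)
    procs = [ctx.Process(target=worker, args=(rank,)) for rank in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

if __name__ == "__main__":
    main()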
+ - pytest -v -s -x lora/test_chatglm3_tp.py + - pytest -v -s -x lora/test_llama_tp.py + - pytest -v -s -x lora/test_llm_with_multi_loras.py + - pytest -v -s -x lora/test_olmoe_tp.py + +- label: Weight Loading Multiple GPU Test # 33min + timeout_in_minutes: 45 + mirror_hardwares: [amdexperimental] + agent_pool: mi325_2 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + mirror_hardwares: [amdexperimental] + agent_pool: mi325_2 + # grade: Blocking + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt + +- label: NixlConnector PD accuracy tests (Distributed) # 30min + mirror_hardwares: [amdexperimental] + agent_pool: mi325_4 + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - tests/v1/kv_connector/nixl_integration/ + commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt + - bash v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh + +##### multi gpus test ##### +##### A100 test ##### + +- label: Distributed Tests (A100) # optional + gpu: a100 + optional: true + num_gpus: 4 + source_file_dependencies: + - vllm/ + commands: + # NOTE: don't test llama model here, it seems hf implementation is buggy + # see https://github.com/vllm-project/vllm/pull/5689 for details + - pytest -v -s distributed/test_custom_all_reduce.py + - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' + - pytest -v -s -x lora/test_mixtral.py + +- label: LM Eval Large Models # optional + gpu: a100 + optional: true + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +##### H200 test ##### +- label: Distributed Tests (H200) # optional + gpu: h200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - pytest -v -s tests/compile/test_async_tp.py + - pytest -v -s tests/compile/test_sequence_parallelism.py + - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm + - pytest -v -s tests/distributed/test_context_parallel.py + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 + +##### B200 test ##### +- label: Distributed Tests (B200) # optional + gpu: b200 + optional: true + working_dir: "/vllm-workspace/" + num_gpus: 2 + commands: + - pytest -v -s tests/distributed/test_context_parallel.py + - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py + +##### RL Integration Tests ##### +- label: Prime-RL Integration Test # 15min + 
mirror_hardwares: [amdexperimental] + agent_pool: mi325_2 + # grade: Blocking + timeout_in_minutes: 30 + optional: true + num_gpus: 2 + working_dir: "/vllm-workspace" + source_file_dependencies: + - vllm/ + - .buildkite/scripts/run-prime-rl-test.sh + commands: + - bash .buildkite/scripts/run-prime-rl-test.sh diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 48dff31c14dc..3f1d50d55810 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -172,6 +172,8 @@ steps: - tests/v1/engine/test_engine_core_client.py - tests/distributed/test_symm_mem_allreduce.py commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 # test with torchrun tp=2 and external_dp=2 - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with torchrun tp=2 and pp=2 @@ -349,7 +351,8 @@ steps: - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py - python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 - - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048 + # https://github.com/vllm-project/vllm/pull/26682 uses slightly more memory in PyTorch 2.9+ causing this test to OOM in 1xL4 GPU + - python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 1536 - label: Platform Tests (CUDA) # 4min timeout_in_minutes: 15 @@ -384,7 +387,12 @@ steps: --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ --ignore=lora/test_chatglm3_tp.py \ --ignore=lora/test_llama_tp.py \ - --ignore=lora/test_llm_with_multi_loras.py + --ignore=lora/test_llm_with_multi_loras.py \ + --ignore=lora/test_olmoe_tp.py \ + --ignore=lora/test_deepseekv2_tp.py \ + --ignore=lora/test_gptoss.py \ + --ignore=lora/test_qwen3moe_tp.py + parallelism: 4 - label: PyTorch Compilation Unit Tests # 15min @@ -400,11 +408,10 @@ steps: - pytest -v -s compile/test_fusion_attn.py - pytest -v -s compile/test_functionalization.py - pytest -v -s compile/test_silu_mul_quant_fusion.py - - pytest -v -s compile/test_sequence_parallelism.py - - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py + - pytest -v -s compile/test_aot_compile.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 @@ -417,8 +424,8 @@ steps: - pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/piecewise/ -- label: PyTorch Fullgraph Test # 20min - timeout_in_minutes: 30 +- label: PyTorch Fullgraph Test # 22min + timeout_in_minutes: 35 mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: @@ -426,6 +433,19 @@ steps: - tests/compile commands: - pytest -v -s compile/test_full_graph.py + - pytest -v -s compile/test_fusions_e2e.py + +- label: Cudagraph test + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - tests/v1/cudagraph + - vllm/v1/cudagraph_dispatcher.py + - vllm/config/compilation.py + - vllm/compilation + 
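Steps throughout this pipeline are gated on changed paths via `source_file_dependencies` lists like the one just above. A hedged sketch of prefix-based gating is below; the real pipeline generator may match paths differently, this only illustrates the idea.

# Illustrative prefix-based step gating, not the actual Buildkite generator logic.
def step_should_run(changed_files: list[str], dependencies: list[str]) -> bool:
    # Run the step if any changed file falls under one of the dependency prefixes.
    return any(f.startswith(dep) for f in changed_files for dep in dependencies)

if __name__ == "__main__":
    deps = ["vllm/compilation", "tests/v1/cudagraph"]   # subset of the step above
    changed = ["vllm/compilation/backends.py", "README.md"]
    print(step_should_run(changed, deps))  # -> True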
commands: + - pytest -v -s v1/cudagraph/test_cudagraph_dispatch.py + - pytest -v -s v1/cudagraph/test_cudagraph_mode.py - label: Kernels Core Operation Test # 48min timeout_in_minutes: 75 @@ -433,8 +453,9 @@ steps: source_file_dependencies: - csrc/ - tests/kernels/core + - tests/kernels/test_top_k_per_row.py commands: - - pytest -v -s kernels/core + - pytest -v -s kernels/core kernels/test_top_k_per_row.py - label: Kernels Attention Test %N # 23min timeout_in_minutes: 35 @@ -527,8 +548,9 @@ steps: # since torchao nightly is only compatible with torch nightly currently # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now # we can only upgrade after this is resolved - - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128 - - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ + # TODO(jerryzh168): resolve the above comment + - uv pip install --system torchao==0.13.0 --index-url https://download.pytorch.org/whl/cu129 + - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py - label: LM Eval Small Models # 53min timeout_in_minutes: 75 @@ -733,6 +755,16 @@ steps: - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work +- label: Multi-Modal Accuracy Eval (Small Models) # 50min + timeout_in_minutes: 70 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - vllm/multimodal/ + - vllm/inputs/ + - vllm/v1/core/ + commands: + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1 + - label: Multi-Modal Models Test (Extended) 1 mirror_hardwares: [amdexperimental] optional: true @@ -796,8 +828,8 @@ steps: # Whisper needs spawn method to avoid deadlock - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper -- label: Blackwell Test # 38 min - timeout_in_minutes: 60 +- label: Blackwell Test # 21 min + timeout_in_minutes: 30 working_dir: "/vllm-workspace/" gpu: b200 # optional: true @@ -810,8 +842,6 @@ steps: - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/v1/attention/backends/flashinfer.py - - vllm/compilation/fusion.py - - vllm/compilation/fusion_attn.py commands: - nvidia-smi - python3 examples/offline_inference/basic/chat.py @@ -828,13 +858,32 @@ steps: - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py + - pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py + - pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - # Fusion - - pytest -v -s tests/compile/test_fusion_all_reduce.py - - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern - pytest -v -s tests/kernels/moe/test_flashinfer.py + +- label: Blackwell Fusion Tests # 30 min + timeout_in_minutes: 40 + working_dir: "/vllm-workspace/" + gpu: b200 + source_file_dependencies: + - csrc/quantization/fp4/ + - 
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py + - vllm/v1/attention/backends/flashinfer.py + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + commands: + - nvidia-smi + - pytest -v -s tests/compile/test_fusion_attn.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py + # this runner has 2 GPUs available even though num_gpus=2 is not set + - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s tests/compile/test_fusions_e2e.py - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 @@ -941,6 +990,8 @@ steps: - tests/v1/shutdown - tests/v1/worker/test_worker_memory_snapshot.py commands: + # https://github.com/NVIDIA/nccl/issues/1838 + - export NCCL_CUMEM_HOST_ENABLE=0 - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py @@ -948,6 +999,7 @@ steps: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - pytest -v -s distributed/test_sequence_parallel.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - pytest -v -s v1/worker/test_worker_memory_snapshot.py @@ -991,6 +1043,11 @@ steps: - pytest -v -s plugins_tests/test_io_processor_plugins.py - pip uninstall prithvi_io_processor_plugin -y # end io_processor plugins test + # begin stat_logger plugins test + - pip install -e ./plugins/vllm_add_dummy_stat_logger + - pytest -v -s plugins_tests/test_stats_logger_plugins.py + - pip uninstall dummy_stat_logger -y + # end stat_logger plugins test # other tests continue here: - pytest -v -s plugins_tests/test_scheduler_plugins.py - pip install -e ./plugins/vllm_add_dummy_model @@ -1030,6 +1087,7 @@ steps: - pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py - pytest -v -s -x lora/test_llm_with_multi_loras.py + - pytest -v -s -x lora/test_olmoe_tp.py - label: Weight Loading Multiple GPU Test # 33min @@ -1055,6 +1113,17 @@ steps: - tests/weight_loading commands: - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt + +- label: NixlConnector PD accuracy tests (Distributed) # 30min + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py + - tests/v1/kv_connector/nixl_integration/ + commands: + - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt + - bash v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh ##### multi gpus test ##### @@ -1087,12 +1156,16 @@ steps: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 ##### H200 test ##### -- label: Distrubted Tests (H200) # optional +- label: Distributed Tests (H200) # optional gpu: h200 optional: true working_dir: "/vllm-workspace/" num_gpus: 2 commands: + - pytest -v -s tests/compile/test_async_tp.py + - pytest -v -s tests/compile/test_sequence_parallelism.py + - pytest -v -s tests/compile/test_fusion_all_reduce.py + - pytest -v -s 
tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 diff --git a/.coveragerc b/.coveragerc index bc6342956109..b7a9fdb4e05a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,5 +1,10 @@ [run] -source = vllm +# Track the installed vllm package (this is what actually gets imported during tests) +# Use wildcard pattern to match the installed location +source = + vllm + */dist-packages/vllm + */site-packages/vllm omit = */tests/* */test_* @@ -12,6 +17,16 @@ omit = */benchmarks/* */docs/* +[paths] +# Map all possible vllm locations to a canonical "vllm" path +# This ensures coverage.combine properly merges data from different test runs +source = + vllm + /vllm-workspace/src/vllm + /vllm-workspace/vllm + */site-packages/vllm + */dist-packages/vllm + [report] exclude_lines = pragma: no cover diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000000..5a601d00cef8 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,4 @@ +# Migrate from `yapf` & `isort` to `ruff` +d6953beb91da4e9c99be4c0a1304a2d24189535c +# Convert `Optional[x]` to `x | None` and `Union[x, y]` to `x | y` +8fcaaf6a165e661f63fc51be906bc05b0767332f diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index dbcad3aa308f..ba08a4335215 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,10 +5,8 @@ /vllm/attention @LucasWilkinson /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn -/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn -/vllm/model_executor/layers/fused_moe @mgoin -/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche -/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 +/vllm/model_executor/layers/fused_moe @mgoin @pavanimajety +/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety /vllm/model_executor/layers/mamba @tdoublep /vllm/model_executor/model_loader @22quinn /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @@ -26,9 +24,9 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345 # vLLM V1 -/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat /vllm/v1/attention @LucasWilkinson -/vllm/v1/attention/backends/flashinfer.py @mgoin +/vllm/v1/attention/backends/mla @pavanimajety +/vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety /vllm/v1/attention/backends/triton_attn.py @tdoublep /vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC /vllm/v1/sample @22quinn @houseroad @njhill @@ -47,7 +45,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256 /tests/models @DarkLight1337 @ywang96 /tests/multimodal @DarkLight1337 @ywang96 @NickLucche -/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256 +/tests/quantization 
@mgoin @robertgshaw2-redhat @yewentao256 @pavanimajety /tests/test_inputs.py @DarkLight1337 @ywang96 /tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm /tests/v1/structured_output @mgoin @russellb @aarnphm @@ -60,7 +58,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/v1/offloading @ApostaC # Transformers backend -/vllm/model_executor/models/transformers.py @hmellor +/vllm/model_executor/models/transformers @hmellor /tests/models/test_transformers.py @hmellor # Docs @@ -121,3 +119,11 @@ mkdocs.yaml @hmellor # KVConnector installation files /requirements/kv_connectors.txt @NickLucche + +# Pooling models +/examples/*/pooling/ @noooop +/tests/models/*/pooling* @noooop +/tests/entrypoints/pooling @noooop +/vllm/config/pooler.py @noooop +/vllm/pooling_params.py @noooop +/vllm/model_executor/layers/pooler.py @noooop diff --git a/.github/workflows/issue_autolabel.yml b/.github/workflows/issue_autolabel.yml index c2b17abe811c..7d565ef9f2e4 100644 --- a/.github/workflows/issue_autolabel.yml +++ b/.github/workflows/issue_autolabel.yml @@ -13,6 +13,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Label issues based on keywords + id: label-step uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | @@ -42,7 +43,6 @@ jobs: searchIn: "body" }, ], - // Substring search - matches anywhere in text (partial matches) substrings: [ { @@ -89,14 +89,12 @@ jobs: term: "hip_", searchIn: "both" }, - // ROCm tools and libraries { term: "hipify", searchIn: "both" }, ], - // Regex patterns - for complex pattern matching regexPatterns: [ { @@ -107,13 +105,17 @@ jobs: } ], }, + // Add more label configurations here as needed + // example: { + // keywords: [...], + // substrings: [...], + // regexPatterns: [...] + // }, }; - // Helper function to create regex based on search type function createSearchRegex(term, type) { // Escape special regex characters in the term const escapedTerm = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); - switch (type) { case 'keyword': // Word boundary search - matches whole words only @@ -125,16 +127,13 @@ jobs: throw new Error(`Unknown search type: ${type}`); } } - // Helper function to find matching terms in text with line information function findMatchingTermsWithLines(text, searchTerms = [], searchType = 'keyword', searchLocation = '') { const matches = []; const lines = text.split('\n'); - for (const termConfig of searchTerms) { let regex; let term, searchIn, pattern, description, flags; - // Handle different input formats (string or object) if (typeof termConfig === 'string') { term = termConfig; @@ -146,21 +145,17 @@ jobs: description = termConfig.description; flags = termConfig.flags; } - // Skip if this term shouldn't be searched in the current location if (searchIn !== 'both' && searchIn !== searchLocation) { continue; } - // Create appropriate regex if (searchType === 'regex') { regex = new RegExp(pattern, flags || "gi"); } else { regex = createSearchRegex(term, searchType); } - const termMatches = []; - // Check each line for matches lines.forEach((line, lineIndex) => { const lineMatches = line.match(regex); @@ -175,15 +170,14 @@ jobs: originalTerm: term || pattern, description: description, // Show context around the match in the line - context: line.length > 100 ? - line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30), - line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...' + context: line.length > 100 ? 
+ line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30), + line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...' : line.trim() }); }); } }); - if (termMatches.length > 0) { matches.push({ term: term || (description || pattern), @@ -196,64 +190,48 @@ jobs: }); } } - return matches; } - // Helper function to check if label should be added async function processLabel(labelName, config) { const body = context.payload.issue.body || ""; const title = context.payload.issue.title || ""; - core.notice(`Processing label: ${labelName}`); core.notice(`Issue Title: "${title}"`); core.notice(`Issue Body length: ${body.length} characters`); - let shouldAddLabel = false; let allMatches = []; let reason = ''; - const keywords = config.keywords || []; const substrings = config.substrings || []; const regexPatterns = config.regexPatterns || []; - core.notice(`Searching with ${keywords.length} keywords, ${substrings.length} substrings, and ${regexPatterns.length} regex patterns`); - // Search in title if (title.trim()) { core.notice(`Searching in title: "${title}"`); - const titleKeywordMatches = findMatchingTermsWithLines(title, keywords, 'keyword', 'title'); const titleSubstringMatches = findMatchingTermsWithLines(title, substrings, 'substring', 'title'); const titleRegexMatches = findMatchingTermsWithLines(title, regexPatterns, 'regex', 'title'); - allMatches.push(...titleKeywordMatches, ...titleSubstringMatches, ...titleRegexMatches); } - // Search in body if (body.trim()) { core.notice(`Searching in body (${body.length} characters)`); - const bodyKeywordMatches = findMatchingTermsWithLines(body, keywords, 'keyword', 'body'); const bodySubstringMatches = findMatchingTermsWithLines(body, substrings, 'substring', 'body'); const bodyRegexMatches = findMatchingTermsWithLines(body, regexPatterns, 'regex', 'body'); - allMatches.push(...bodyKeywordMatches, ...bodySubstringMatches, ...bodyRegexMatches); } - if (allMatches.length > 0) { core.notice(`Found ${allMatches.length} matching term(s):`); - for (const termMatch of allMatches) { const locationText = termMatch.searchLocation === 'title' ? 'title' : 'body'; const searchInText = termMatch.searchIn === 'both' ? 'both' : termMatch.searchIn; - if (termMatch.searchType === 'regex') { core.notice(` 📍 Regex: "${termMatch.term}" (pattern: ${termMatch.pattern}) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); } else { core.notice(` 📍 Term: "${termMatch.term}" (${termMatch.searchType} search) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); } - // Show details for each match termMatch.matches.forEach((match, index) => { core.notice(` ${index + 1}. 
Line ${match.lineNumber} in ${match.searchLocation}: "${match.match}" [${match.searchType}]`); @@ -266,7 +244,6 @@ jobs: } }); } - shouldAddLabel = true; const totalMatches = allMatches.reduce((sum, t) => sum + t.count, 0); const titleMatches = allMatches.filter(t => t.searchLocation === 'title').reduce((sum, t) => sum + t.count, 0); @@ -274,13 +251,10 @@ jobs: const keywordMatches = allMatches.filter(t => t.searchType === 'keyword').reduce((sum, t) => sum + t.count, 0); const substringMatches = allMatches.filter(t => t.searchType === 'substring').reduce((sum, t) => sum + t.count, 0); const regexMatches = allMatches.filter(t => t.searchType === 'regex').reduce((sum, t) => sum + t.count, 0); - reason = `Found ${totalMatches} total matches (${titleMatches} in title, ${bodyMatches} in body) - ${keywordMatches} keyword matches, ${substringMatches} substring matches, ${regexMatches} regex matches`; } - core.notice(`Final decision: ${shouldAddLabel ? 'ADD LABEL' : 'DO NOT ADD LABEL'}`); core.notice(`Reason: ${reason || 'No matching terms found'}`); - if (shouldAddLabel) { const existingLabels = context.payload.issue.labels.map(l => l.name); if (!existingLabels.includes(labelName)) { @@ -296,14 +270,92 @@ jobs: core.notice(`Label "${labelName}" already present.`); return false; } - core.notice(`No matching terms found for label "${labelName}".`); return false; } - // Process all configured labels - const processLabels = Object.entries(labelConfig) - .map(([labelName, config]) => processLabel(labelName, config)); - const labelsAdded = await Promise.all(processLabels); - const numLabelsAdded = labelsAdded.reduce((x, y) => x + y, 0); - core.notice(`Processing complete. ${numLabelsAdded} label(s) added.`); \ No newline at end of file + const labelsAddedResults = await Promise.all( + Object.entries(labelConfig).map(([labelName, config]) => + processLabel(labelName, config).then(added => ({ labelName, added })) + ) + ); + + const numLabelsAdded = labelsAddedResults.filter(r => r.added).length; + core.notice(`Processing complete. 
${numLabelsAdded} label(s) added.`); + + // Return which labels were added for the next step + const addedLabels = labelsAddedResults.filter(r => r.added).map(r => r.labelName); + core.setOutput('labels_added', JSON.stringify(addedLabels)); + return addedLabels; + + - name: CC users for labeled issues + if: steps.label-step.outputs.labels_added != '[]' + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + script: | + // Configuration: Map labels to GitHub users to CC + // You can add multiple users per label, and multiple label configurations + const ccConfig = { + rocm: { + users: ['hongxiayang', 'tjtanaa', 'vllmellm'], // Add more users as needed: ['user1', 'user2', 'user3'] + message: 'CC {users} for ROCm-related issue' // {users} will be replaced with @mentions + }, + // Add more label -> user mappings here + // Example: + // cuda: { + // users: ['user1', 'user2'], + // message: 'CC {users} for CUDA-related issue' + // }, + // performance: { + // users: ['perfexpert'], + // message: 'CC {users} for performance issue' + // }, + }; + + const labelsAdded = JSON.parse('${{ steps.label-step.outputs.labels_added }}'); + core.notice(`Labels added: ${labelsAdded.join(', ')}`); + + // Get existing comments to check for already mentioned users + const comments = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }); + + const issueBody = context.payload.issue.body || ''; + const allExistingText = issueBody + '\n' + comments.data.map(c => c.body).join('\n'); + + // Process each label that was added + for (const label of labelsAdded) { + if (ccConfig[label]) { + const config = ccConfig[label]; + const usersToMention = []; + + // Check which users haven't been mentioned yet + for (const user of config.users) { + const mentionPattern = new RegExp(`@${user}\\b`, 'i'); + if (!mentionPattern.test(allExistingText)) { + usersToMention.push(user); + } else { + core.notice(`@${user} already mentioned for label "${label}", skipping`); + } + } + + // Post comment if there are users to mention + if (usersToMention.length > 0) { + const mentions = usersToMention.map(u => `@${u}`).join(' '); + const message = config.message.replace('{users}', mentions); + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: message + }); + + core.notice(`CC comment added for label "${label}": ${mentions}`); + } else { + core.notice(`All users for label "${label}" already mentioned, skipping comment`); + } + } + } \ No newline at end of file diff --git a/.gitignore b/.gitignore index b1df673e83ca..ffa36dee1ab9 100644 --- a/.gitignore +++ b/.gitignore @@ -94,6 +94,9 @@ ipython_config.py # generated files **/generated/** +# uv +uv.lock + # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: diff --git a/.markdownlint.yaml b/.markdownlint.yaml index c86fed9555d6..cd9df57cd980 100644 --- a/.markdownlint.yaml +++ b/.markdownlint.yaml @@ -4,7 +4,6 @@ MD013: false MD024: siblings_only: true MD033: false -MD042: false MD045: false MD046: false MD051: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95a3866e6bb8..fbfd8016cb76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,17 +7,18 @@ default_stages: exclude: 'vllm/third_party/.*' repos: - repo: 
https://github.com/astral-sh/ruff-pre-commit - rev: v0.13.3 + rev: v0.14.0 hooks: - id: ruff-check args: [--output-format, github, --fix] - id: ruff-format - repo: https://github.com/crate-ci/typos - rev: v1.35.5 + rev: v1.38.1 hooks: - id: typos + args: [--force-exclude] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v20.1.3 + rev: v21.1.2 hooks: - id: clang-format exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' @@ -34,10 +35,10 @@ repos: hooks: - id: actionlint - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.6.17 + rev: 0.9.1 hooks: - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28] + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28] files: ^requirements/test\.(in|txt)$ - repo: local hooks: @@ -47,19 +48,14 @@ repos: entry: python tools/generate_nightly_torch_test.py files: ^requirements/test\.(in|txt)$ - id: mypy-local - name: Run mypy for local Python installation - entry: python tools/pre_commit/mypy.py 0 "local" + name: Run mypy locally for lowest supported Python version + entry: python tools/pre_commit/mypy.py 0 "3.10" stages: [pre-commit] # Don't run in CI <<: &mypy_common language: python types_or: [python, pyi] require_serial: true additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.9 - entry: python tools/pre_commit/mypy.py 1 "3.9" - <<: *mypy_common - stages: [manual] # Only run in CI - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 entry: python tools/pre_commit/mypy.py 1 "3.10" @@ -75,6 +71,11 @@ repos: entry: python tools/pre_commit/mypy.py 1 "3.12" <<: *mypy_common stages: [manual] # Only run in CI + - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.13 + entry: python tools/pre_commit/mypy.py 1 "3.13" + <<: *mypy_common + stages: [manual] # Only run in CI - id: shellcheck name: Lint shell scripts entry: tools/shellcheck.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 5ebea1c42e9a..7cb94f919f12 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,7 +34,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13") +set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151") @@ -49,8 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1 # requirements.txt files and should be kept consistent. 
The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.8.0") -set(TORCH_SUPPORTED_VERSION_ROCM "2.8.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.9.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.9.0") # # Try to find python package with an executable that exactly matches @@ -269,8 +269,8 @@ set(VLLM_EXT_SRC "csrc/sampler.cu" "csrc/cuda_view.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" - "csrc/quantization/fp8/common.cu" + "csrc/quantization/w8a8/int8/scaled_quant.cu" + "csrc/quantization/w8a8/fp8/common.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" @@ -314,12 +314,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/awq/gemm_kernels.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" - "csrc/quantization/fp8/per_token_group_quant.cu") + "csrc/quantization/w8a8/fp8/per_token_group_quant.cu" + "csrc/quantization/w8a8/int8/per_token_group_quant.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -423,11 +424,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) set(SRCS - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu") + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -458,9 +459,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu" + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -492,9 +493,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" - "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu" + "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu" + "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -525,7 +526,7 @@ if(VLLM_GPU_LANG 
STREQUAL "CUDA") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") @@ -648,7 +649,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -672,7 +673,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -697,7 +698,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") @@ -720,7 +721,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu") + set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -882,6 +883,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/moe_align_sum_kernels.cu" + "csrc/moe/moe_lora_align_sum_kernels.cu" "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") @@ -1006,6 +1008,7 @@ endif() # For CUDA we also build and ship some external projects. 
if (VLLM_GPU_LANG STREQUAL "CUDA") include(cmake/external_projects/flashmla.cmake) + include(cmake/external_projects/qutlass.cmake) # vllm-flash-attn should be last as it overwrites some CMake functions include(cmake/external_projects/vllm_flash_attn.cmake) diff --git a/README.md b/README.md index 6772a9eae073..3dcdd7dc0094 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ Compute Resources: - Trainy - UC Berkeley - UC San Diego +- Volcengine Slack Sponsor: Anyscale diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index ba7c733be0b2..4021fede7215 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -8,7 +8,6 @@ import time import traceback from dataclasses import dataclass, field -from typing import Optional, Union import aiohttp import huggingface_hub.constants @@ -28,13 +27,13 @@ class RequestFuncInput: prompt_len: int output_len: int model: str - model_name: Optional[str] = None - logprobs: Optional[int] = None - extra_body: Optional[dict] = None - multi_modal_content: Optional[dict | list[dict]] = None + model_name: str | None = None + logprobs: int | None = None + extra_body: dict | None = None + multi_modal_content: dict | list[dict] | None = None ignore_eos: bool = False - language: Optional[str] = None - request_id: Optional[str] = None + language: str | None = None + request_id: str | None = None @dataclass @@ -52,7 +51,7 @@ class RequestFuncOutput: async def async_request_tgi( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith("generate_stream") @@ -133,7 +132,7 @@ async def async_request_tgi( async def async_request_trt_llm( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith("generate_stream") @@ -204,7 +203,7 @@ async def async_request_trt_llm( async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith(("completions", "profile")), ( @@ -267,7 +266,7 @@ async def async_request_deepspeed_mii( async def async_request_openai_completions( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith(("completions", "profile")), ( @@ -367,7 +366,7 @@ async def async_request_openai_completions( async def async_request_openai_chat_completions( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith(("chat/completions", "profile")), ( @@ -476,7 +475,7 @@ async def async_request_openai_chat_completions( async def async_request_openai_audio( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: # Lazy import without PlaceholderModule to avoid vllm dep. 
import soundfile @@ -610,7 +609,7 @@ def get_tokenizer( tokenizer_mode: str = "auto", trust_remote_code: bool = False, **kwargs, -) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: +) -> PreTrainedTokenizer | PreTrainedTokenizerFast: if pretrained_model_name_or_path is not None and not os.path.exists( pretrained_model_name_or_path ): diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index b5e2613de1cd..d7dc0e991c4d 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -32,7 +32,6 @@ import json import random import time -from typing import Optional from transformers import PreTrainedTokenizerBase @@ -80,7 +79,7 @@ def sample_requests_from_dataset( num_requests: int, tokenizer: PreTrainedTokenizerBase, input_length_range: tuple[int, int], - fixed_output_len: Optional[int], + fixed_output_len: int | None, ) -> list[Request]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -128,7 +127,7 @@ def sample_requests_from_random( num_requests: int, tokenizer: PreTrainedTokenizerBase, input_length_range: tuple[int, int], - fixed_output_len: Optional[int], + fixed_output_len: int | None, prefix_len: int, ) -> list[Request]: requests = [] diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index bb453791c186..769f52dbab6e 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -7,7 +7,6 @@ import json import random import time -from typing import Optional from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -24,7 +23,7 @@ def sample_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int], + fixed_output_len: int | None, ) -> list[tuple[str, int, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 58b9767d0939..539ab2ed0a4d 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -31,8 +31,8 @@ import uuid import warnings from collections.abc import AsyncGenerator +from contextlib import nullcontext from dataclasses import dataclass -from typing import Optional import datasets import numpy as np @@ -316,7 +316,7 @@ def calculate_metrics( tokenizer: PreTrainedTokenizerBase, selected_percentile_metrics: list[str], selected_percentiles: list[float], - goodput_config_dict: Optional[dict[str, float]] = None, + goodput_config_dict: dict[str, float] | None = None, ) -> tuple[BenchmarkMetrics, list[int]]: actual_output_lens: list[int] = [] total_input = 0 @@ -436,9 +436,9 @@ async def benchmark( selected_percentile_metrics: list[str], selected_percentiles: list[str], ignore_eos: bool, - max_concurrency: Optional[int], + max_concurrency: int | None, structured_output_ratio: float, - goodput_config_dict: Optional[dict[str, float]] = None, + goodput_config_dict: dict[str, float] | None = None, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -502,15 +502,9 @@ def prepare_extra_body(request) -> dict: pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - # This can be used once the minimum Python version is 3.10 or higher, - # and it will simplify the code in limited_request_func. 
- # semaphore = (asyncio.Semaphore(max_concurrency) - # if max_concurrency else contextlib.nullcontext()) - semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else nullcontext() async def limited_request_func(request_func_input, pbar): - if semaphore is None: - return await request_func(request_func_input=request_func_input, pbar=pbar) async with semaphore: return await request_func(request_func_input=request_func_input, pbar=pbar) diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index 98624abdf49f..f0d661f9d534 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -6,7 +6,7 @@ import os import time from types import TracebackType -from typing import Any, Optional, Union +from typing import Any def convert_to_pytorch_benchmark_format( @@ -92,7 +92,7 @@ class TimeCollector: def __init__(self, scale: int) -> None: self.cnt: int = 0 self._sum: int = 0 - self._max: Optional[int] = None + self._max: int | None = None self.scale = scale self.start_time: int = time.monotonic_ns() @@ -104,13 +104,13 @@ def collect(self, v: int) -> None: else: self._max = max(self._max, v) - def avg(self) -> Union[float, str]: + def avg(self) -> float | str: return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A" - def max(self) -> Union[float, str]: + def max(self) -> float | str: return self._max / self.scale if self._max else "N/A" - def dump_avg_max(self) -> list[Union[float, str]]: + def dump_avg_max(self) -> list[float | str]: return [self.avg(), self.max()] def __enter__(self) -> None: @@ -118,8 +118,8 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - exc_traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, ) -> None: self.collect(time.monotonic_ns() - self.start_time) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py index 9ec270bbd2e9..22fc2678fd1c 100644 --- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -6,8 +6,7 @@ import itertools import pickle as pkl import time -from collections.abc import Iterable -from typing import Callable +from collections.abc import Callable, Iterable import torch import torch.utils.benchmark as TBenchmark diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 02f8c593392c..2deebf3ddb7a 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -6,8 +6,7 @@ import itertools import pickle as pkl import time -from collections.abc import Iterable -from typing import Callable, Optional +from collections.abc import Callable, Iterable import torch import torch.utils.benchmark as TBenchmark @@ -53,7 +52,7 @@ def bench_int8( n: int, label: str, sub_label: str, - bench_kernels: Optional[list[str]] = None, + bench_kernels: list[str] | None = None, ) -> Iterable[TMeasurement]: """Benchmark INT8-based kernels.""" assert dtype == torch.int8 @@ -108,7 +107,7 @@ def bench_fp8( n: int, label: str, sub_label: str, - bench_kernels: Optional[list[str]] = None, + bench_kernels: list[str] | None = None, ) -> Iterable[TMeasurement]: """Benchmark FP8-based kernels.""" assert dtype == torch.float8_e4m3fn @@ 
-183,7 +182,7 @@ def bench( n: int, label: str, sub_label: str, - bench_kernels: Optional[list[str]] = None, + bench_kernels: list[str] | None = None, ) -> Iterable[TMeasurement]: if dtype == torch.int8: return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) @@ -201,7 +200,7 @@ def print_timers(timers: Iterable[TMeasurement]): def run( dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]], - bench_kernels: Optional[list[str]] = None, + bench_kernels: list[str] | None = None, ) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index 901524214469..d809bf1db8cb 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -3,10 +3,9 @@ import pickle as pkl import time -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass from itertools import product -from typing import Callable, Optional import torch import torch.utils.benchmark as TBenchmark @@ -51,7 +50,7 @@ def get_bench_params() -> list[bench_params_t]: def unfused_int8_impl( rms_norm_layer: RMSNorm, x: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, quant_dtype: torch.dtype, ): # Norm @@ -68,7 +67,7 @@ def unfused_int8_impl( def unfused_fp8_impl( rms_norm_layer: RMSNorm, x: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, quant_dtype: torch.dtype, ): # Norm @@ -85,7 +84,7 @@ def unfused_fp8_impl( def fused_impl( rms_norm_layer: RMSNorm, # this stores the weights x: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, quant_dtype: torch.dtype, ): out, _ = ops.rms_norm_dynamic_per_token_quant( diff --git a/benchmarks/kernels/bench_mxfp4_qutlass.py b/benchmarks/kernels/bench_mxfp4_qutlass.py new file mode 100644 index 000000000000..dfc7721876a1 --- /dev/null +++ b/benchmarks/kernels/bench_mxfp4_qutlass.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import argparse +import copy +import itertools + +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "mxfp4": dict(no_a_quant=False, enabled=True), + "mxfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _quant_weight_mxfp4( + b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str +): + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx( + b, forward_hadamard_matrix, method="abs_max" + ) + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton") + return weight_hf_e2m1, weight_hf_scale_block + + +def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device): + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4( + b, forward_hadamard_matrix, device + ) + alpha = torch.tensor([1.0], device="cuda") + + if cfg["no_a_quant"]: + # Pre-quantize activation + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( + a, forward_hadamard_matrix, method="abs_max" + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") + + def run(): + return matmul_mxf4_bf16_tn( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + ) + + return run + + # Quantize activation on-the-fly + def run(): + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx( + a, forward_hadamard_matrix, method="abs_max" + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton") + return matmul_mxf4_bf16_tn( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 24576, + 32768, + ], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs MXFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, had_size): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_mxfp4_runner( + cfg, a, b, forward_hadamard_matrix, dtype, device + ) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), rep=200, quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] 
//= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + for had_size in [32, 64, 128]: + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_mxfp4_res_n{N}_k{K}", + N=N, + K=K, + had_size=had_size, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/bench_nvfp4_qutlass.py b/benchmarks/kernels/bench_nvfp4_qutlass.py new file mode 100644 index 000000000000..6fecc816f946 --- /dev/null +++ b/benchmarks/kernels/bench_nvfp4_qutlass.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import copy +import itertools + +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm +from vllm._custom_ops import fusedQuantizeNv +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "nvfp4": dict(no_a_quant=False, enabled=True), + "nvfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _quant_weight_nvfp4( + b: torch.Tensor, + forward_hadamard_matrix: torch.Tensor, + global_scale: torch.Tensor, + device: str, + M: int, + N: int, + K: int, +): + weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv( + b, forward_hadamard_matrix, global_scale + ) + weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + return weight_hf_e2m1, weight_hf_scale_block + + +def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K): + alpha = torch.tensor([1.0], device="cuda") + global_scale = torch.tensor([1.0], device="cuda") + weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4( + b, forward_hadamard_matrix, global_scale, device, M, N, K + ) + + if cfg["no_a_quant"]: + # Pre-quantize activation + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv( + a, forward_hadamard_matrix, global_scale + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + + def run(): + 
return ops.cutlass_scaled_fp4_mm( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + torch.bfloat16, + ) + + return run + + # Quantize activation on-the-fly + def run(): + input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv( + a, forward_hadamard_matrix, global_scale + ) + input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view( + -1, K // 16 + ) + return ops.cutlass_scaled_fp4_mm( + input_hf_e2m1, + weight_hf_e2m1, + input_hf_scale_block, + weight_hf_scale_block, + alpha, + torch.bfloat16, + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 24576, + 32768, + ], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs NVFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K, had_size): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_nvfp4_runner( + cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K + ) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), rep=200, quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.3-70B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + for had_size in [16, 32, 64, 128]: + print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_nvfp4_res_n{N}_k{K}", + N=N, + K=K, + had_size=had_size, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py index e08e5680c191..d33b84fc3601 100644 --- a/benchmarks/kernels/bench_per_token_quant_fp8.py +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools -from typing import Callable +from collections.abc import Callable from unittest.mock import patch import pandas as pd @@ -10,7 +10,8 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.triton_utils import triton -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from 
vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE def with_triton_mode(fn): diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py index 93edbcc9391f..7662655b5efa 100644 --- a/benchmarks/kernels/benchmark_activation.py +++ b/benchmarks/kernels/benchmark_activation.py @@ -10,7 +10,8 @@ from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform from vllm.triton_utils import triton -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE batch_size_range = [1, 16, 32, 64, 128] seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py index 4cbdde5a5b2c..df06a940e6d4 100644 --- a/benchmarks/kernels/benchmark_device_communicators.py +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -22,8 +22,8 @@ import json import os import time +from collections.abc import Callable from contextlib import nullcontext -from typing import Callable, Optional import torch import torch.distributed as dist @@ -264,12 +264,12 @@ def benchmark_allreduce( def benchmark_allreduce_single( self, sequence_length: int, - allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]], + allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None], should_use_fn: Callable[[torch.Tensor], bool], context, num_warmup: int, num_trials: int, - ) -> Optional[float]: + ) -> float | None: """Benchmark method with CUDA graph optimization.""" try: # Create test tensor (2D: sequence_length x hidden_size) diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index 69978ec6b23e..bcfa64c3f425 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -7,7 +7,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 799b16999873..39338f338761 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -6,11 +6,12 @@ import json import pickle import time +from collections.abc import Callable from dataclasses import dataclass from enum import Enum, auto from itertools import product from pathlib import Path -from typing import Any, Callable, Optional +from typing import Any import torch import torch.utils.benchmark as TBenchmark @@ -158,7 +159,7 @@ def ref_group_gemm( seq_lens_cpu: torch.Tensor, prompt_lora_mapping_cpu: torch.Tensor, scaling: float, - add_inputs: Optional[bool], + add_inputs: bool | None, ): """ Torch group gemm reference implementation to test correctness of @@ -316,8 +317,8 @@ class BenchmarkContext: lora_rank: int sort_by_lora_id: bool dtype: torch.dtype - seq_length: Optional[int] = None - num_slices: Optional[int] = None # num_slices for slice based ops + seq_length: int | None = None + num_slices: int | None = None # num_slices for slice based ops def with_seq_length(self, seq_length: int) -> "BenchmarkContext": ctx = copy.copy(self) @@ -561,7 +562,7 @@ def 
as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: } def bench_fn_kwargs( - self, op_type: OpType, add_inputs: Optional[bool] = None + self, op_type: OpType, add_inputs: bool | None = None ) -> dict[str, Any]: if op_type.is_shrink_fn(): assert add_inputs is None @@ -575,7 +576,7 @@ def bench_fn_kwargs( raise ValueError(f"Unrecognized optype {self}") def test_correctness( - self, op_type: OpType, expand_fn_add_inputs: Optional[bool] + self, op_type: OpType, expand_fn_add_inputs: bool | None ) -> bool: """ Test correctness of op_type implementation against a grouped gemm @@ -611,8 +612,8 @@ def bench_optype( ctx: BenchmarkContext, arg_pool_size: int, op_type: OpType, - cuda_graph_nops: Optional[int] = None, - expand_fn_add_inputs: Optional[bool] = None, + cuda_graph_nops: int | None = None, + expand_fn_add_inputs: bool | None = None, test_correctness: bool = False, ) -> TMeasurement: assert arg_pool_size >= 1 @@ -679,7 +680,7 @@ def bench_torch_mm( ctx: BenchmarkContext, arg_pool_size: int, op_type: OpType, - cuda_graph_nops: Optional[int] = None, + cuda_graph_nops: int | None = None, ) -> TMeasurement: """ Benchmark basic torch.mm as a roofline. @@ -744,7 +745,7 @@ def use_cuda_graph_recommendation() -> str: """ -def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None): +def print_timers(timers: list[TMeasurement], args: argparse.Namespace | None = None): compare = TBenchmark.Compare(timers) compare.print() diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 1b1c3b321cce..e1d5239f5cc9 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -8,10 +8,9 @@ import os import pickle as pkl import time -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass from itertools import product -from typing import Callable, Optional import pandas as pd import torch @@ -63,23 +62,23 @@ class BenchmarkTensors: a: torch.Tensor w_q: torch.Tensor - group_size: Optional[int] + group_size: int | None wtype: ScalarType w_g_s: torch.Tensor - w_g_zp: Optional[torch.Tensor] - w_ch_s: Optional[torch.Tensor] - w_tok_s: Optional[torch.Tensor] + w_g_zp: torch.Tensor | None + w_ch_s: torch.Tensor | None + w_tok_s: torch.Tensor | None @dataclass class TypeConfig: act_type: torch.dtype weight_type: ScalarType - output_type: Optional[torch.dtype] - group_scale_type: Optional[torch.dtype] - group_zero_type: Optional[torch.dtype] - channel_scale_type: Optional[torch.dtype] - token_scale_type: Optional[torch.dtype] + output_type: torch.dtype | None + group_scale_type: torch.dtype | None + group_zero_type: torch.dtype | None + channel_scale_type: torch.dtype | None + token_scale_type: torch.dtype | None def rand_data(shape, dtype=torch.float16, scale=1): @@ -93,8 +92,8 @@ def quantize_and_pack( atype: torch.dtype, w: torch.Tensor, wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], + stype: torch.dtype | None, + group_size: int | None, zero_points: bool = False, ): assert wtype.is_integer(), "TODO: support floating point weights" @@ -113,7 +112,7 @@ def quantize_and_pack( def create_bench_tensors( - shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] + shape: tuple[int, int, int], types: TypeConfig, group_size: int | None ) -> list[BenchmarkTensors]: m, n, k = shape @@ -331,8 +330,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]) return res 
-_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None -_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None +_SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None +_SWEEP_SCHEDULES_RESULTS_CSV: str | None = None def bench( diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 02c2db674d4b..9298d3b58dfb 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -579,10 +579,12 @@ def main(args: argparse.Namespace): E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size + hidden_size = config.hidden_size elif config.architectures[0] == "JambaForCausalLM": E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size + hidden_size = config.hidden_size elif config.architectures[0] in ( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", @@ -592,6 +594,7 @@ def main(args: argparse.Namespace): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size + hidden_size = config.hidden_size elif config.architectures[0] in ( "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM", @@ -600,10 +603,18 @@ def main(args: argparse.Namespace): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size + hidden_size = config.hidden_size + elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration": + text_config = config.get_text_config() + E = text_config.num_experts + topk = text_config.num_experts_per_tok + intermediate_size = text_config.moe_intermediate_size + hidden_size = text_config.hidden_size elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): E = config.num_experts topk = config.moe_topk[0] intermediate_size = config.moe_intermediate_size[0] + hidden_size = config.hidden_size else: # Support for llama4 config = config.get_text_config() @@ -611,6 +622,7 @@ def main(args: argparse.Namespace): E = config.num_local_experts topk = config.num_experts_per_tok intermediate_size = config.intermediate_size + hidden_size = config.hidden_size enable_ep = bool(args.enable_expert_parallel) if enable_ep: ensure_divisibility(E, args.tp_size, "Number of experts") @@ -619,8 +631,7 @@ def main(args: argparse.Namespace): else: ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") shard_intermediate_size = 2 * intermediate_size // args.tp_size - hidden_size = config.hidden_size - dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + dtype = torch.float16 if current_platform.is_rocm() else config.dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" block_quant_shape = get_weight_block_size_safety(config) diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index 04d2205aa372..459eafa6d907 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -344,7 +344,7 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok hidden_size = config.hidden_size - dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + dtype = torch.float16 if current_platform.is_rocm() else config.dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" use_customized_permute = args.use_customized_permute diff --git a/benchmarks/kernels/benchmark_paged_attention.py 
b/benchmarks/kernels/benchmark_paged_attention.py index 7e0376c18ecc..1b1e71adeec4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -3,16 +3,15 @@ import random import time -from typing import Optional import torch from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, create_kv_caches_with_random, ) @@ -37,7 +36,7 @@ def main( seed: int, do_profile: bool, device: str = "cuda", - kv_cache_dtype: Optional[str] = None, + kv_cache_dtype: str | None = None, ) -> None: current_platform.seed_everything(seed) diff --git a/benchmarks/kernels/benchmark_per_token_group_quant.py b/benchmarks/kernels/benchmark_per_token_group_quant.py index 1ccb5e08b3d5..bdc1eb733084 100644 --- a/benchmarks/kernels/benchmark_per_token_group_quant.py +++ b/benchmarks/kernels/benchmark_per_token_group_quant.py @@ -3,8 +3,8 @@ import argparse import math +from collections.abc import Callable from contextlib import contextmanager -from typing import Callable from unittest.mock import patch import torch diff --git a/benchmarks/kernels/benchmark_polynorm.py b/benchmarks/kernels/benchmark_polynorm.py deleted file mode 100644 index 9ac8f5e6594e..000000000000 --- a/benchmarks/kernels/benchmark_polynorm.py +++ /dev/null @@ -1,155 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import itertools - -import torch - -from vllm import _custom_ops as vllm_ops -from vllm.triton_utils import triton - - -def polynorm_naive( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float = 1e-6, -): - orig_shape = x.shape - x = x.view(-1, x.shape[-1]) - - def norm(x, eps: float): - return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) - - x = x.float() - return ( - ( - weight[0] * norm(x**3, eps) - + weight[1] * norm(x**2, eps) - + weight[2] * norm(x, eps) - + bias - ) - .to(weight.dtype) - .view(orig_shape) - ) - - -def polynorm_vllm( - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float = 1e-6, -): - orig_shape = x.shape - x = x.view(-1, x.shape[-1]) - - out = torch.empty_like(x) - vllm_ops.poly_norm(out, x, weight, bias, eps) - output = out - - output = output.view(orig_shape) - return output - - -def calculate_diff(batch_size, seq_len, hidden_dim): - dtype = torch.bfloat16 - x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") - weight = torch.ones(3, dtype=dtype, device="cuda") - bias = torch.ones(1, dtype=dtype, device="cuda") - - output_naive = polynorm_naive(x, weight, bias) - output_vllm = polynorm_vllm(x, weight, bias) - - if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): - print("✅ All implementations match") - else: - print("❌ Implementations differ") - - -batch_size_range = [2**i for i in range(0, 7, 2)] -seq_length_range = [2**i for i in range(6, 11, 1)] -dim_range = [2048, 4096] -configs = list(itertools.product(dim_range, batch_size_range, seq_length_range)) - - -def get_benchmark(): - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["dim", "batch_size", "seq_len"], - x_vals=[list(_) for _ in configs], - line_arg="provider", - line_vals=["naive", "vllm"], - line_names=["Naive", "vLLM"], - styles=[("blue", "-"), ("red", "-")], - ylabel="us", - 
plot_name="polynorm-perf", - args={}, - ) - ) - def benchmark(dim, batch_size, seq_len, provider): - dtype = torch.bfloat16 - hidden_dim = dim * 4 - - x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda") - weight = torch.ones(3, dtype=dtype, device="cuda") - bias = torch.ones(1, dtype=dtype, device="cuda") - - quantiles = [0.5, 0.2, 0.8] - - if provider == "naive": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: polynorm_naive(x, weight, bias), - quantiles=quantiles, - ) - else: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: polynorm_vllm(x, weight, bias), - quantiles=quantiles, - ) - - return 1000 * ms, 1000 * max_ms, 1000 * min_ms - - return benchmark - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--batch-size", - type=int, - default=4, - help="Batch size", - ) - parser.add_argument( - "--seq-len", - type=int, - default=128, - help="Sequence length", - ) - parser.add_argument( - "--hidden-dim", - type=int, - default=8192, - help="Intermediate size of MLP", - ) - parser.add_argument( - "--save-path", - type=str, - default="./configs/polnorm/", - help="Path to save polnorm benchmark results", - ) - - args = parser.parse_args() - - # Run correctness test - calculate_diff( - batch_size=args.batch_size, - seq_len=args.seq_len, - hidden_dim=args.hidden_dim, - ) - - benchmark = get_benchmark() - # Run performance benchmark - benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 6ab26f5f1adf..61427a77b4e3 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -7,7 +7,8 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py index af9841daadf2..e0ff09d4b397 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import random import time @@ -11,9 +9,9 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py index 0aace571064a..29f1b2ccdcf6 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import random import time @@ -14,9 +12,9 @@ ) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, 
create_kv_caches_with_random_flash, ) diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py index 4cf633a81358..d8d7f5bcf9da 100644 --- a/benchmarks/kernels/benchmark_rmsnorm.py +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools -from typing import Optional, Union import torch from flashinfer.norm import fused_add_rmsnorm, rmsnorm @@ -21,8 +20,8 @@ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: def forward( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: @@ -41,7 +40,7 @@ def forward( def rmsnorm_naive( x: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, + residual: torch.Tensor | None = None, eps: float = 1e-6, ): naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) @@ -65,7 +64,7 @@ def rmsnorm_naive( def rmsnorm_flashinfer( x: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, + residual: torch.Tensor | None = None, eps: float = 1e-6, ): orig_shape = x.shape @@ -89,7 +88,7 @@ def rmsnorm_flashinfer( def rmsnorm_vllm( x: torch.Tensor, weight: torch.Tensor, - residual: Optional[torch.Tensor] = None, + residual: torch.Tensor | None = None, eps: float = 1e-6, ): orig_shape = x.shape diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index b81baf17a8c6..24869c91a8d7 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from itertools import accumulate -from typing import Optional import nvtx import torch @@ -18,7 +17,7 @@ def benchmark_rope_kernels_multi_lora( seq_len: int, num_heads: int, head_size: int, - rotary_dim: Optional[int], + rotary_dim: int | None, dtype: torch.dtype, seed: int, device: str, diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py index c7a4066b39d7..a5887aafd30d 100644 --- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -1,5 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Comprehensive SiLU Benchmark Suite + +This benchmark compares the following SiLU implementations: +1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation +2. 
Triton Kernel - Triton-based implementation + +The suite generates detailed performance comparisons including: +- Memory bandwidth utilization +- Speedup ratios (baseline vs optimized implementations) +- Performance across different expert configurations and token distributions +""" + from collections.abc import Callable import matplotlib.pyplot as plt @@ -7,7 +21,7 @@ import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm_cuda, + persistent_masked_m_silu_mul_quant, ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton( num_parallel_tokens, group_size: int = 128, eps: float = 1e-10, + expert_offsets: torch.Tensor = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales @@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton( # Parse generation strategies -strategies = ["uniform", "max_t", "first_t"] +strategies = ["random_imbalanced", "uniform", "max_t"] def benchmark( @@ -195,15 +210,27 @@ def generate_data(seed_offset=0): current_platform.seed_everything(42 + seed_offset) y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() - if gen_strategy == "uniform": - r = torch.rand(size=(E,), device="cuda") + if gen_strategy == "random_imbalanced": + + def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"): + mean = total_tokens // n_e + min_max = mean // ratio + e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean + e[0] = min_max + r = torch.rand(size=(E - 1,)) + r /= r.sum() + r *= total_tokens - min_max + r = r.round().long() + e[1:] = r.to(device=device) + return e + + tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda") + elif gen_strategy == "uniform": + r = torch.rand(size=(E,)) r /= r.sum() r *= total_tokens - tokens_per_expert = r.int() - tokens_per_expert = torch.minimum( - tokens_per_expert, - torch.ones((E,), device=r.device, dtype=torch.int) * T, - ) + r = r.round().long() + tokens_per_expert = r elif gen_strategy == "max_t": tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") tokens_per_expert.fill_(total_tokens / E) @@ -281,40 +308,34 @@ def generate_data(seed_offset=0): def create_comparison_plot( - ratio, cuda_times, baseline_times, config_labels, strategy_name, id + ratios, silu_v2_times, triton_times, config_labels, strategy_name, id ): - """Create a comparison plot for a specific generation strategy""" - fig, ax = plt.subplots(1, 1, figsize=(16, 6)) + fig, ax = plt.subplots(1, 1, figsize=(18, 6)) # Configure x-axis positions x = np.arange(len(config_labels)) - width = 0.35 + width = 0.25 # Execution Time plot (lower is better) + ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue") ax.bar( - x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue" - ) - ax.bar( - x + width / 2, - baseline_times, - width, - label="Baseline", - alpha=0.8, - color="orange", + x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green" ) - # Add speedup labels over each bar pair + # Add speedup labels over each bar trio for i in range(len(x)): - speedup = ratio[i] - max_height = max(cuda_times[i], baseline_times[i]) + triton_v2_speedup = ratios[i][1] # triton/v2 + max_height = max(silu_v2_times[i], triton_times[i]) + + # Triton/V2 speedup ax.text( - x[i], + x[i] + width / 2, max_height + max_height * 0.02, - 
f"{speedup:.2f}x", + f"{triton_v2_speedup:.2f}x", ha="center", va="bottom", fontweight="bold", - fontsize=9, + fontsize=8, ) ax.set_xlabel("Configuration") @@ -332,56 +353,75 @@ def create_comparison_plot( def create_combined_plot(all_results): - """Create a combined plot with all strategies in one PNG""" num_strategies = len(all_results) - fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies)) + fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies)) if num_strategies == 1: axes = [axes] for idx, ( strategy_name, - ratio, - cuda_times, - baseline_times, + all_ratios, + all_silu_v2_results, + all_triton_results, config_labels, + config_x_axis, ) in enumerate(all_results): ax = axes[idx] + # Flatten the nested results to get bandwidth percentages for plotting + silu_v2_bandwidths = [] + triton_bandwidths = [] + flat_ratios = [] + + for config_results in all_silu_v2_results: + for result in config_results: + silu_v2_bandwidths.append(result[3]) # bandwidth percentage + + for config_results in all_triton_results: + for result in config_results: + triton_bandwidths.append(result[3]) # bandwidth percentage + + for config_ratios in all_ratios: + for ratio in config_ratios: + flat_ratios.append(ratio) + # Configure x-axis positions x = np.arange(len(config_labels)) - width = 0.35 + width = 0.25 - # Execution Time plot (lower is better) + # Bandwidth utilization plot (higher is better) ax.bar( - x - width / 2, - cuda_times, + x, + silu_v2_bandwidths, width, - label="CUDA Kernel", + label="SiLU V2 (CUDA)", alpha=0.8, color="blue", ) ax.bar( - x + width / 2, - baseline_times, + x + width, + triton_bandwidths, width, - label="Baseline", + label="Triton Kernel", alpha=0.8, - color="orange", + color="green", ) - # Add speedup labels over each bar pair + # Add speedup labels over each bar trio for i in range(len(x)): - speedup = ratio[i] - max_height = max(cuda_times[i], baseline_times[i]) + triton_v2_speedup = flat_ratios[i] # triton/v2 + max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i]) + + # Triton/V2 speedup ax.text( - x[i], + x[i] + width / 2, max_height + max_height * 0.02, - f"{speedup:.2f}x", + f"{triton_v2_speedup:.2f}x", ha="center", va="bottom", fontweight="bold", - fontsize=9, + fontsize=8, ) ax.set_xlabel("Configuration") @@ -395,7 +435,7 @@ def create_combined_plot(all_results): ax.grid(True, alpha=0.3) plt.tight_layout() - filename = "../../silu_bench/silu_benchmark_combined.png" + filename = "silu_benchmark_combined_3way.png" plt.savefig(filename, dpi=300, bbox_inches="tight") plt.show() @@ -405,7 +445,9 @@ def create_combined_plot(all_results): outer_dim = 7168 configs = [ # DeepSeekV3 Configs + # (1, 56, 7168), (8, 1024, 7168), + # (32, 56, 7168), # DeepSeekV3 Configs (32, 1024, 7168), # DeepSeekV3 Configs @@ -417,6 +459,7 @@ def create_combined_plot(all_results): strategy_descriptions = { "uniform": "Uniform Random", + "random_imbalanced": "Imbalanced Random", "max_t": "Even Assignment", "first_t": "experts[0] = T, experts[1:] = 0", } @@ -433,28 +476,31 @@ def create_combined_plot(all_results): print(f"Testing strategy: {strategy_descriptions[strategy]}") print(f"{'=' * 60}") - # Collect benchmark data for both algorithms + # Collect benchmark data for all three algorithms config_labels = [] config_x_axis = [] - all_cuda_results = [] - all_baseline_results = [] + all_silu_v2_results = [] + all_triton_results = [] all_ratios = [] for E, T, H in configs: - total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E] + 
total_tokens_config = [] + for i in [8, 16, 32, 64, 128, 256, 512]: + if i <= T: + total_tokens_config.append(i * E) config_x_axis.append(total_tokens_config) - cuda_results = [] - baseline_results = [] + silu_v2_results = [] + triton_results = [] ratios = [] for total_tokens in total_tokens_config: config_label = f"E={E},T={T},H={H},TT={total_tokens}" config_labels.append(config_label) - # CUDA kernel results - time_ms_cuda, gflops, gbps, perc = benchmark( - silu_mul_fp8_quant_deep_gemm_cuda, + # SiLU V2 (CUDA kernel) results + time_ms_silu_v2, gflops, gbps, perc = benchmark( + persistent_masked_m_silu_mul_quant, E, T, H, @@ -463,9 +509,9 @@ def create_combined_plot(all_results): num_warmups=num_warmups, gen_strategy=strategy, ) - cuda_results.append((time_ms_cuda, gflops, gbps, perc)) + silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc)) - # Baseline results + # Triton kernel results time_ms_triton, gflops, gbps, perc = benchmark( silu_mul_fp8_quant_deep_gemm_triton, E, @@ -476,12 +522,20 @@ def create_combined_plot(all_results): num_warmups=num_warmups, gen_strategy=strategy, ) - baseline_results.append((time_ms_triton, gflops, gbps, perc)) - ratios.append(time_ms_triton / time_ms_cuda) + triton_results.append((time_ms_triton, gflops, gbps, perc)) - print(f"Completed: {config_label}") - all_cuda_results.append(cuda_results) - all_baseline_results.append(baseline_results) + # Calculate speedup ratios (triton baseline / implementation) + triton_v2_ratio = time_ms_triton / time_ms_silu_v2 + ratios.append(triton_v2_ratio) + + print( + f"Completed: {config_label}:" + f" V2: {time_ms_silu_v2:.3f}ms," + f" Triton: {time_ms_triton:.3f}ms" + ) + + all_silu_v2_results.append(silu_v2_results) + all_triton_results.append(triton_results) all_ratios.append(ratios) # Store results for combined plotting @@ -489,8 +543,8 @@ def create_combined_plot(all_results): ( strategy_descriptions[strategy], all_ratios, - all_cuda_results, - all_baseline_results, + all_silu_v2_results, + all_triton_results, config_labels, config_x_axis, ) @@ -498,15 +552,18 @@ def create_combined_plot(all_results): # Print summary table for this strategy print(f"\nSummary Table - {strategy_descriptions[strategy]}:") - print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}") - print("-" * 60) + print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}") + print("-" * 90) for i, (E, T, H) in enumerate(configs): - speedup = baseline_results[i][0] / cuda_results[i][0] + # Get the first result for each config (simplifying for summary) + v2_time = silu_v2_results[i][0] + triton_time = triton_results[i][0] + triton_v2_speedup = triton_time / v2_time config_label = f"E={E:3d},T={T:4d},H={H:4d}" print( - f"{config_label:<20} {cuda_results[i][0]:8.5f} " - f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x" + f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} " + f"{triton_v2_speedup:8.2f}x" ) @@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results): num_strategies = len(all_results) num_configs = len(configs) - # Create side-by-side subplots: 2 columns for speedup and bandwidth percentage fig, axs = plt.subplots( - num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies) + num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies) ) # Add main title to the entire figure fig.suptitle( - "Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)", - fontsize=16, + "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)", + fontsize=18, 
fontweight="bold", y=0.98, ) @@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results): ( strategy_name, all_ratios, - all_cuda_results, - all_baseline_results, + all_silu_v2_results, + all_triton_results, config_labels, config_x_axis, ) = result @@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results): ratios = all_ratios[config_idx] total_tokens_values = config_x_axis[config_idx] - # Extract CUDA and Triton bandwidth percentages - cuda_bandwidth_percentages = [ - result[3] for result in all_cuda_results[config_idx] + # Extract speedup ratios + triton_v2_ratios = [ratio for ratio in ratios] + + # Extract bandwidth percentages for all implementations + v2_bandwidth_percentages = [ + result[3] for result in all_silu_v2_results[config_idx] ] triton_bandwidth_percentages = [ - result[3] for result in all_baseline_results[config_idx] + result[3] for result in all_triton_results[config_idx] ] # Plot speedup ratios vs total tokens (left plot) ax_speedup.plot( - total_tokens_values, ratios, "bo-", linewidth=3, markersize=8 + total_tokens_values, + triton_v2_ratios, + "go-", + linewidth=3, + markersize=8, + label="Triton/V2 Speedup", ) ax_speedup.set_title( - f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}", + f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}", fontsize=12, fontweight="bold", ) ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) + ax_speedup.legend(prop={"weight": "bold"}) ax_speedup.grid(True, alpha=0.3) + # Plot bandwidth utilization (right plot) ax_bandwidth.plot( total_tokens_values, - cuda_bandwidth_percentages, - "ro-", + v2_bandwidth_percentages, + "o-", linewidth=3, markersize=8, - label="CUDA", + label="SiLU V2", + color="blue", ) ax_bandwidth.plot( total_tokens_values, triton_bandwidth_percentages, - "go-", + "o-", linewidth=3, markersize=8, label="Triton", + color="green", ) ax_bandwidth.set_title( f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", @@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results): for label in ax.get_xticklabels() + ax.get_yticklabels(): label.set_fontweight("bold") - # Add value labels on speedup points - for x, y in zip(total_tokens_values, ratios): + # Add value labels on Triton/V2 speedup points + for x, y in zip(total_tokens_values, triton_v2_ratios): ax_speedup.annotate( f"{y:.2f}x", (x, y), textcoords="offset points", - xytext=(0, 12), - ha="center", - fontsize=10, - fontweight="bold", - bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7), - ) - - # Add value labels on CUDA bandwidth points - for x, y in zip(total_tokens_values, cuda_bandwidth_percentages): - ax_bandwidth.annotate( - f"{y:.1f}%", - (x, y), - textcoords="offset points", - xytext=(0, 12), - ha="center", - fontsize=9, - fontweight="bold", - bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3), - ) - - # Add value labels on Triton bandwidth points - for x, y in zip(total_tokens_values, triton_bandwidth_percentages): - ax_bandwidth.annotate( - f"{y:.1f}%", - (x, y), - textcoords="offset points", xytext=(0, -15), ha="center", fontsize=9, @@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results): plt.tight_layout() plt.subplots_adjust(top=0.93) # Make room for main title - filename = "silu_benchmark_total_tokens.png" + filename = "silu_benchmark_total_tokens_3way.png" plt.savefig(filename, dpi=300, bbox_inches="tight") plt.show() return filename -# Create combined plot with all strategies 
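# [Editor's note: illustrative sketch, not part of this patch.]
# The plots above read `result[3]` as a bandwidth-utilization percentage and
# report `time_ms_triton / time_ms_silu_v2` as the speedup, so values above
# 1.00x mean the SiLU V2 CUDA kernel beats the Triton baseline. The
# utilization figure is conventionally derived along these lines, where
# `peak_bw_gbps` is a hypothetical device constant (e.g. ~3350 GB/s for an
# H100 SXM):
#
#   achieved_gbps = bytes_read_plus_written / (time_ms * 1e-3) / 1e9
#   utilization_pct = 100.0 * achieved_gbps / peak_bw_gbps
#
# The exact byte count depends on E, T, H and on the FP8 output/scale layout
# used by benchmark(), which is defined outside this hunk.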
-combined_plot_filename = create_total_tokens_plot(all_results) +# Create comprehensive 3-way comparison plots +combined_plot_filename = create_combined_plot(all_results) +total_tokens_plot_filename = create_total_tokens_plot(all_results) -print(f"\n{'=' * 60}") -print("Benchmark Complete!") -print(f"Generated combined plot: {combined_plot_filename}") -print(f"{'=' * 60}") +print(f"\n{'=' * 80}") +print("3-Way Benchmark Suite Complete!") +print(f"Generated combined comparison plot: {combined_plot_filename}") +print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}") +print("Compared: SiLU V2 (CUDA), and Triton implementations") +print(f"{'=' * 80}") diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 6ddab4621457..f7cdc25794ca 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -4,7 +4,6 @@ import csv import os from datetime import datetime -from typing import Optional import flashinfer import torch @@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_decode( dtype: torch.dtype, - quant_dtypes: tuple[ - Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] - ], + quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None], batch_size: int, max_seq_len: int, num_heads: tuple[int, int] = (64, 8), diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 131df74c7de1..7993354475fc 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -4,7 +4,6 @@ import csv import os from datetime import datetime -from typing import Optional import flashinfer import torch @@ -28,9 +27,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.no_grad() def benchmark_prefill( dtype: torch.dtype, - quant_dtypes: tuple[ - Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] - ], + quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None], batch_size: int, max_seq_len: int, num_heads: tuple[int, int] = (64, 8), diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index c6c8e0b0b936..602fad181074 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -14,7 +14,7 @@ from tqdm import tqdm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - _w8a8_block_fp8_matmul, + _w8a8_triton_block_scaled_mm, ) from vllm.platforms import current_platform from vllm.triton_utils import triton @@ -83,7 +83,7 @@ def grid(META): ) if A.dtype == torch.float8_e4m3fn: - kernel = _w8a8_block_fp8_matmul + kernel = _w8a8_triton_block_scaled_mm else: raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index 4bbb36bb4359..a9af811bbe9c 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from collections.abc import Iterable -from typing import Any, Callable, Optional +from collections.abc import Callable, Iterable +from typing import Any import torch import torch.utils.benchmark as TBenchmark @@ -55,7 +55,7 @@ def n_args(self): def __init__( self, - 
cuda_graph_params: Optional[CudaGraphBenchParams], + cuda_graph_params: CudaGraphBenchParams | None, label: str, sub_label: str, description: str, diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py index 67b937930d58..2674899d1cc5 100644 --- a/benchmarks/multi_turn/bench_dataset.py +++ b/benchmarks/multi_turn/bench_dataset.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from statistics import mean -from typing import Any, NamedTuple, Optional, Union +from typing import Any, NamedTuple import numpy as np # type: ignore import pandas as pd # type: ignore @@ -35,8 +35,8 @@ def sample(self, size: int = 1) -> np.ndarray: class UniformDistribution(Distribution): def __init__( self, - min_val: Union[int, float], - max_val: Union[int, float], + min_val: int | float, + max_val: int | float, is_integer: bool = True, ) -> None: self.min_val = min_val @@ -56,7 +56,7 @@ def __repr__(self) -> str: class ConstantDistribution(Distribution): - def __init__(self, value: Union[int, float]) -> None: + def __init__(self, value: int | float) -> None: self.value = value self.max_val = value @@ -68,7 +68,7 @@ def __repr__(self) -> str: class ZipfDistribution(Distribution): - def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: + def __init__(self, alpha: float, max_val: int | None = None) -> None: self.alpha = alpha self.max_val = max_val @@ -83,7 +83,7 @@ def __repr__(self) -> str: class PoissonDistribution(Distribution): - def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: + def __init__(self, alpha: float, max_val: int | None = None) -> None: self.alpha = alpha self.max_val = max_val @@ -100,11 +100,11 @@ def __repr__(self) -> str: class LognormalDistribution(Distribution): def __init__( self, - mean: Optional[float] = None, - sigma: Optional[float] = None, - average: Optional[int] = None, - median_ratio: Optional[float] = None, - max_val: Optional[int] = None, + mean: float | None = None, + sigma: float | None = None, + average: int | None = None, + median_ratio: float | None = None, + max_val: int | None = None, ) -> None: self.average = average self.median_ratio = median_ratio diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py index 66d85eaf5131..67a085b40ed3 100644 --- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -13,7 +13,7 @@ from enum import Enum from http import HTTPStatus from statistics import mean -from typing import NamedTuple, Optional, Union +from typing import NamedTuple import aiohttp # type: ignore import numpy as np # type: ignore @@ -46,9 +46,9 @@ def __str__(self): class ClientArgs(NamedTuple): seed: int - max_num_requests: Optional[int] + max_num_requests: int | None skip_first_turn: bool - max_turns: Optional[int] + max_turns: int | None max_active_conversations: int verbose: bool print_content: bool @@ -109,9 +109,9 @@ def __str__(self) -> str: class MetricStats: def __init__(self) -> None: - self.min: Optional[float] = None - self.max: Optional[float] = None - self.avg: Optional[float] = None + self.min: float | None = None + self.max: float | None = None + self.avg: float | None = None self.sum = 0.0 self.count = 0 @@ -143,7 +143,7 @@ def __init__(self, window_size: int) -> None: self.index = 0 self.sum = 0.0 self.count = 0 - self.avg: Optional[float] = None + self.avg: float | None = 
None def update(self, new_value: float) -> None: if self.count < self.window_size: @@ -169,7 +169,7 @@ def __repr__(self) -> str: class DebugStats: def __init__(self, logger: logging.Logger, window_size: int) -> None: self.logger = logger - self.metrics: dict[str, Union[MovingAverage, MetricStats]] = { + self.metrics: dict[str, MovingAverage | MetricStats] = { "moving_avg_ttft_ms": MovingAverage(window_size), "moving_avg_tpot_ms": MovingAverage(window_size), "ttft_ms": MetricStats(), @@ -198,14 +198,6 @@ def print(self) -> None: self.logger.info("-" * 50) -# Must support Python 3.8, we can't use str.removeprefix(prefix) -# introduced in Python 3.9 -def remove_prefix(text: str, prefix: str) -> str: - if text.startswith(prefix): - return text[len(prefix) :] - return text - - def nanosec_to_millisec(value: float) -> float: return value / 1000000.0 @@ -220,8 +212,8 @@ async def send_request( chat_url: str, model: str, stream: bool = True, - min_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, + min_tokens: int | None = None, + max_tokens: int | None = None, ) -> ServerResponse: payload = { "model": model, @@ -250,9 +242,9 @@ async def send_request( timeout = aiohttp.ClientTimeout(total=timeout_sec) valid_response = True - ttft: Optional[float] = None + ttft: float | None = None chunk_delay: list[int] = [] - latency: Optional[float] = None + latency: float | None = None first_chunk = "" generated_text = "" @@ -269,7 +261,7 @@ async def send_request( if not chunk_bytes: continue - chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk == "[DONE]": # End of stream latency = time.perf_counter_ns() - start_time @@ -364,7 +356,7 @@ async def send_turn( req_args: RequestArgs, verbose: bool, verify_output: bool, -) -> Optional[RequestStats]: +) -> RequestStats | None: assert messages_to_use > 0 assert messages_to_use <= len(conversation_messages) @@ -644,7 +636,7 @@ async def client_main( if args.verbose: curr_time_sec: float = time.perf_counter() - time_since_last_turn: Union[str, float] = "N/A" + time_since_last_turn: str | float = "N/A" if conv_id in time_of_last_turn: time_since_last_turn = round( curr_time_sec - time_of_last_turn[conv_id], 3 @@ -769,7 +761,7 @@ def get_client_config( "Number of conversations must be equal or larger than the number of clients" ) - max_req_per_client: Optional[int] = None + max_req_per_client: int | None = None if args.max_num_requests is not None: # Max number of requests per client req_per_client = args.max_num_requests // args.num_clients @@ -936,13 +928,13 @@ async def main_mp( f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501 ) - rps: Union[str, float] = round(len(client_metrics) / runtime_sec, 3) + rps: str | float = round(len(client_metrics) / runtime_sec, 3) if len(client_metrics) < (5 * bench_args.num_clients): # Do not estimate the RPS if the number of samples is very low # (threshold can be tuned if needed) rps = "N/A" - runtime_left_sec: Union[str, float] = round( + runtime_left_sec: str | float = round( (runtime_sec / finished_convs) * (total_convs - finished_convs), 3 ) if percent < 0.05: @@ -1032,7 +1024,7 @@ def process_statistics( warmup_percentages: list[float], test_params: dict, verbose: bool, - gen_conv_args: Optional[GenConvArgs] = None, + gen_conv_args: GenConvArgs | None = None, excel_output: bool = False, ) -> None: if 
len(client_metrics) == 0: @@ -1259,7 +1251,7 @@ async def main() -> None: default=None, help="The model name used in the API. " "If not specified, the model name will be the " - "same as the ``--model`` argument. ", + "same as the `--model` argument. ", ) parser.add_argument( diff --git a/benchmarks/multi_turn/convert_sharegpt_to_openai.py b/benchmarks/multi_turn/convert_sharegpt_to_openai.py index c3622c99a2e5..fccab4d0ce21 100644 --- a/benchmarks/multi_turn/convert_sharegpt_to_openai.py +++ b/benchmarks/multi_turn/convert_sharegpt_to_openai.py @@ -13,7 +13,7 @@ import json import random from statistics import mean -from typing import Any, Optional +from typing import Any import pandas as pd # type: ignore import tqdm # type: ignore @@ -25,7 +25,7 @@ def has_non_english_chars(text: str) -> bool: def content_is_valid( - content: str, min_content_len: Optional[int], max_content_len: Optional[int] + content: str, min_content_len: int | None, max_content_len: int | None ) -> bool: if min_content_len and len(content) < min_content_len: return False @@ -37,7 +37,7 @@ def content_is_valid( def print_stats( - conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None + conversations: "list[dict[Any, Any]]", tokenizer: AutoTokenizer | None = None ) -> None: # Collect statistics stats = [] @@ -109,12 +109,12 @@ def convert_sharegpt_to_openai( seed: int, input_file: str, output_file: str, - max_items: Optional[int], - min_content_len: Optional[int] = None, - max_content_len: Optional[int] = None, - min_turns: Optional[int] = None, - max_turns: Optional[int] = None, - model: Optional[str] = None, + max_items: int | None, + min_content_len: int | None = None, + max_content_len: int | None = None, + min_turns: int | None = None, + max_turns: int | None = None, + model: str | None = None, ) -> None: if min_turns and max_turns: assert min_turns <= max_turns diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index c962564c8da0..4b8f0daacb00 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -188,34 +188,66 @@ else() message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") endif() -# -# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) -# Flag to enable ACL kernels for AARCH64 platforms -if (VLLM_BUILD_ACL STREQUAL "ON") - set(USE_ACL ON) -else() - set(USE_ACL OFF) -endif() +# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms) if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) - FetchContent_Declare( - oneDNN - GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.9 - GIT_PROGRESS TRUE - GIT_SHALLOW TRUE - ) + # Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64 + # TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN + if(ASIMD_FOUND) + if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}") + message(STATUS "Using ACL from specified source directory: $ENV{ACL_ROOT_DIR}") + else() + message(STATUS "Downloading Arm Compute Library (ACL) from GitHub") + FetchContent_Populate(arm_compute + SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-subbuild" + SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-src" + GIT_REPOSITORY https://github.com/ARM-software/ComputeLibrary.git + GIT_TAG v52.2.0 + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE + ) + set(ENV{ACL_ROOT_DIR} 
"${arm_compute_SOURCE_DIR}") + endif() - if(USE_ACL) - find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/) - if(NOT ARM_COMPUTE_LIBRARY) - message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR") + # Build ACL with scons + include(ProcessorCount) + ProcessorCount(_NPROC) + execute_process( + COMMAND scons -j${_NPROC} + Werror=0 debug=0 neon=1 examples=0 embed_kernels=0 os=linux + arch=armv8.2-a build=native benchmark_examples=0 fixed_format_kernels=1 + multi_isa=1 openmp=1 cppthreads=0 + WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}" + RESULT_VARIABLE _acl_rc + ) + if(NOT _acl_rc EQUAL 0) + message(FATAL_ERROR "ACL SCons build failed (exit ${_acl_rc}).") endif() + set(ONEDNN_AARCH64_USE_ACL "ON") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") add_compile_definitions(VLLM_USE_ACL) endif() + set(FETCHCONTENT_SOURCE_DIR_ONEDNN "$ENV{FETCHCONTENT_SOURCE_DIR_ONEDNN}" CACHE PATH "Path to a local oneDNN source directory.") + + if(FETCHCONTENT_SOURCE_DIR_ONEDNN) + message(STATUS "Using oneDNN from specified source directory: ${FETCHCONTENT_SOURCE_DIR_ONEDNN}") + FetchContent_Declare( + oneDNN + SOURCE_DIR ${FETCHCONTENT_SOURCE_DIR_ONEDNN} + ) + else() + message(STATUS "Downloading oneDNN from GitHub") + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.9 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + endif() + set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_BUILD_DOC "OFF") set(ONEDNN_BUILD_EXAMPLES "OFF") @@ -227,7 +259,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") - set(ONEDNN_VERBOSE "ON") + set(ONEDNN_VERBOSE "OFF") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) FetchContent_MakeAvailable(oneDNN) @@ -309,4 +341,4 @@ define_gpu_extension_target( WITH_SOABI ) -message(STATUS "Enabling C extension.") +message(STATUS "Enabling C extension.") \ No newline at end of file diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index c9e7aec880b9..f661084ec48a 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/vllm-project/FlashMLA - GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f + GIT_TAG 46d64a8ebef03fa50b4ae74937276a5c940e3f95 GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS) ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu ) set(FlashMLA_INCLUDES diff --git a/cmake/external_projects/qutlass.cmake b/cmake/external_projects/qutlass.cmake new file mode 100644 index 000000000000..5a59a409999a --- /dev/null +++ b/cmake/external_projects/qutlass.cmake @@ -0,0 +1,97 @@ +include(FetchContent) + +set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory") + +if(DEFINED ENV{QUTLASS_SRC_DIR}) + set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR}) +endif() + +if(QUTLASS_SRC_DIR) + FetchContent_Declare( + qutlass + SOURCE_DIR ${QUTLASS_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + qutlass + GIT_REPOSITORY 
https://github.com/IST-DASLab/qutlass.git + GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +endif() + +FetchContent_Populate(qutlass) + +if(NOT qutlass_SOURCE_DIR) + message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.") +endif() +message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}") + +cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS) + + if(QUTLASS_ARCHS MATCHES "10\\.0a") + set(QUTLASS_TARGET_CC 100) + elseif(QUTLASS_ARCHS MATCHES "12\\.0a") + set(QUTLASS_TARGET_CC 120) + else() + message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.") + endif() + + set(QUTLASS_SOURCES + ${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp + ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu + ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu + ) + + set(QUTLASS_INCLUDES + ${qutlass_SOURCE_DIR} + ${qutlass_SOURCE_DIR}/qutlass + ${qutlass_SOURCE_DIR}/qutlass/csrc/include + ${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions + ) + + if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h") + list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}") + elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h") + list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include") + message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).") + else() + message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. 
" + "Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include") + endif() + + set_gencode_flags_for_srcs( + SRCS "${QUTLASS_SOURCES}" + CUDA_ARCHS "${QUTLASS_ARCHS}" + ) + + target_sources(_C PRIVATE ${QUTLASS_SOURCES}) + target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES}) + target_compile_definitions(_C PRIVATE + QUTLASS_DISABLE_PYBIND=1 + TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC} + ) + + set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS + $<$:--expt-relaxed-constexpr --use_fast_math -O3> + ) + +else() + if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8") + message(STATUS + "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).") + else() + message(STATUS + "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in " + "CUDA_ARCHS='${CUDA_ARCHS}'.") + endif() +endif() diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index e6686275cabb..931090db50e9 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 4695e6bed5366c41e28c06cd86170166e4f43d00 + GIT_TAG a893712401d70362fbb299cd9c4b3476e8e9ed54 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000000..304c0be8105f --- /dev/null +++ b/codecov.yml @@ -0,0 +1,12 @@ +codecov: + require_ci_to_pass: false + +fixes: + # Map source code paths to repository root paths + # Wildcards match any Python version (python3.*) + - "/vllm-workspace/src/vllm/::vllm/" + - "/vllm-workspace/vllm/::vllm/" + - "/usr/local/lib/python3.*/dist-packages/vllm/::vllm/" + - "/usr/local/lib/python3.*/site-packages/vllm/::vllm/" + - "/usr/lib/python3.*/dist-packages/vllm/::vllm/" + - "/usr/lib/python3.*/site-packages/vllm/::vllm/" diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 57382c1ddc65..052ff168cec4 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -28,10 +28,10 @@ #ifdef USE_ROCM #include - #include "../quantization/fp8/amd/quant_utils.cuh" + #include "../quantization/w8a8/fp8/amd/quant_utils.cuh" typedef __hip_bfloat16 __nv_bfloat16; #else - #include "../quantization/fp8/nvidia/quant_utils.cuh" + #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh" #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp index 297d94dcc063..2d4b4a67d242 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -125,32 +125,37 @@ class MLA { } static void set_split_kv (KernelArguments& args) { - // printf("set_split_kv start"); if (args.split_kv >= 1) return; auto [H, K, D, B] = args.problem_shape; - // std::cout << H << " " << K << " " << D << " " << B << "\n"; int sm_count = args.hw_info.sm_count; - // printf(" sm_count = %d\n", sm_count); - int max_splits = ceil_div(K, 128); - max_splits = min(16, max_splits); - - // TODO: This avoids a hang when the batch size larger than 1 and - // there is more than 1 kv_splits. - // Discuss with NVIDIA how this can be fixed. 
- if (B > 1) { - max_splits = min(1, max_splits); + float seq_length_k = static_cast(K) / 1024.0f; + int max_splits = 1; + + if (B <= 4 && seq_length_k >= 16) { + max_splits = 16; + } + else if (B <= 8 && seq_length_k >= 4) { + max_splits = 8; + } + else if ((B <= 16 && seq_length_k >= 8) || + (B == 48 && seq_length_k >= 32)) { + max_splits = 4; + } + else if ((B <= 32 && seq_length_k >= 16) || + (B == 96 && seq_length_k >= 16)) { + max_splits = 2; } - - // printf(" max_splits = %d\n", max_splits); + else { + max_splits = 1; + } + + // Wave-aware scheduling: ensure integer number of waves in K dimension int sms_per_batch = max(1, sm_count / B); - // printf(" sms_per_batch = %d\n", sms_per_batch); int split_heur = min(max_splits, sms_per_batch); int waves = ceil_div(B * split_heur, sm_count); int k_waves = ceil_div(max_splits, split_heur); int split_wave_aware = ceil_div(max_splits, k_waves); args.split_kv = split_wave_aware; - // printf(" args.split_kv = %d\n", args.split_kv); - } /// Determines whether the GEMM can execute the given problem. diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index f4b116c94f19..0aa0dc14c748 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -9,9 +9,9 @@ #include "quantization/vectorization_utils.cuh" #ifdef USE_ROCM - #include "quantization/fp8/amd/quant_utils.cuh" + #include "quantization/w8a8/fp8/amd/quant_utils.cuh" #else - #include "quantization/fp8/nvidia/quant_utils.cuh" + #include "quantization/w8a8/fp8/nvidia/quant_utils.cuh" #endif #include diff --git a/csrc/core/batch_invariant.hpp b/csrc/core/batch_invariant.hpp index 19e422e4b80c..fffe96b86857 100644 --- a/csrc/core/batch_invariant.hpp +++ b/csrc/core/batch_invariant.hpp @@ -5,12 +5,15 @@ namespace vllm { -// vllm_kernel_override_batch_invariant(); returns true -// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1 -inline bool vllm_kernel_override_batch_invariant() { - std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"; - const char* val = std::getenv(env_key.c_str()); - return (val && std::atoi(val) != 0) ? 1 : 0; +// vllm_is_batch_invariant(); returns true +// if env VLLM_BATCH_INVARIANT=1 +inline bool vllm_is_batch_invariant() { + static bool cached = []() { + std::string env_key = "VLLM_BATCH_INVARIANT"; + const char* val = std::getenv(env_key.c_str()); + return (val && std::atoi(val) != 0) ? 
1 : 0; + }(); + return cached; } } // namespace vllm diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index 0f0cc34602b3..bb43aeee2eaf 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -187,7 +187,8 @@ template <> struct hash { size_t operator()( const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { - return hash()(val.b_n_size) ^ hash()(val.b_k_size); + return hash()(val.b_n_size) ^ hash()(val.b_k_size) ^ + hash()(static_cast(val.b_type)); } }; @@ -216,7 +217,8 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l, const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { - return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size; + return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && + l.b_type == r.b_type; } bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l, @@ -493,8 +495,10 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) { dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache( const MSizeCacheKey& key) { if (m_size_cache_.get() == nullptr) { - ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_}; - m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_); + ClassMatmulCacheKey class_key = { + .b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_}; + m_size_cache_ = + get_matul_class_primitive_cache(class_key, primitive_cache_size_); } return m_size_cache_->get_or_create(key, [&]() { dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h index f0cb197d81a3..58ffe7a19bd4 100644 --- a/csrc/cpu/dnnl_helper.h +++ b/csrc/cpu/dnnl_helper.h @@ -199,6 +199,7 @@ class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { struct ClassMatmulCacheKey { dnnl_dim_t b_n_size; dnnl_dim_t b_k_size; + dnnl::memory::data_type b_type; friend bool operator==(const ClassMatmulCacheKey& l, const ClassMatmulCacheKey& r); diff --git a/csrc/cub_helpers.h b/csrc/cub_helpers.h index 470a63a22cab..18e4e343ad8b 100644 --- a/csrc/cub_helpers.h +++ b/csrc/cub_helpers.h @@ -12,6 +12,7 @@ using CubMaxOp = cub::Max; #endif // CUB_VERSION #else #include -using CubAddOp = cub::Sum; -using CubMaxOp = cub::Max; +namespace cub = hipcub; +using CubAddOp = hipcub::Sum; +using CubMaxOp = hipcub::Max; #endif // USE_ROCM diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 5e742d0b0293..34fb64c413db 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum -from typing import Union from cutlass_library import * @@ -22,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum): TmaWarpSpecializedCooperative = enum_auto() -VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = { +VLLMDataTypeNames: dict[VLLMDataType | DataType, str] = { **DataTypeNames, # type: ignore **{ VLLMDataType.u4b8: "u4b8", @@ -30,7 +29,7 @@ class MixedInputKernelScheduleType(enum.Enum): }, } -VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { +VLLMDataTypeTag: dict[VLLMDataType | DataType, str] = { **DataTypeTag, # type: ignore **{ VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", @@ -38,7 +37,7 @@ class MixedInputKernelScheduleType(enum.Enum): }, } -VLLMDataTypeSize: dict[Union[VLLMDataType, 
DataType], int] = { +VLLMDataTypeSize: dict[VLLMDataType | DataType, int] = { **DataTypeSize, # type: ignore **{ VLLMDataType.u4b8: 4, @@ -46,7 +45,7 @@ class MixedInputKernelScheduleType(enum.Enum): }, } -VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { +VLLMDataTypeVLLMScalarTypeTag: dict[VLLMDataType | DataType, str] = { VLLMDataType.u4b8: "vllm::kU4B8", VLLMDataType.u8b128: "vllm::kU8B128", DataType.u4: "vllm::kU4", @@ -57,7 +56,7 @@ class MixedInputKernelScheduleType(enum.Enum): DataType.bf16: "vllm::kBfloat16", } -VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { +VLLMDataTypeTorchDataTypeTag: dict[VLLMDataType | DataType, str] = { DataType.u8: "at::ScalarType::Byte", DataType.s8: "at::ScalarType::Char", DataType.e4m3: "at::ScalarType::Float8_e4m3fn", @@ -67,9 +66,7 @@ class MixedInputKernelScheduleType(enum.Enum): DataType.f32: "at::ScalarType::Float", } -VLLMKernelScheduleTag: dict[ - Union[MixedInputKernelScheduleType, KernelScheduleType], str -] = { +VLLMKernelScheduleTag: dict[MixedInputKernelScheduleType | KernelScheduleType, str] = { **KernelScheduleTag, # type: ignore **{ MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", # noqa: E501 diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 6c3685f6f7cd..8cfcf9f41283 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -2,6 +2,7 @@ #include "dispatch_utils.h" #include "cub_helpers.h" #include "core/batch_invariant.hpp" +#include "quantization/vectorization_utils.cuh" #include #include @@ -18,11 +19,22 @@ __global__ void rms_norm_kernel( const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; + const scalar_t* input_row = input + blockIdx.x * input_stride; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float)input[blockIdx.x * input_stride + idx]; + constexpr int VEC_SIZE = 8; + auto vec_op = [&variance](const vec_n_t& vec) { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + float x = static_cast(vec.val[i]); + variance += x * x; + } + }; + auto scalar_op = [&variance](const scalar_t& val) { + float x = static_cast(val); variance += x * x; - } + }; + vllm::vectorize_read_with_alignment( + input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; @@ -136,211 +148,6 @@ fused_add_rms_norm_kernel( } } -/* Function specialization in the case of FP16/BF16 tensors. - Additional optimizations we can make in this case are - packed and vectorized operations, which help with the - memory latency bottleneck. - - _f16VecPN struct extends _f16Vec to add operations specifically required for - polynomial normalization (poly norm). - The original _f16Vec does not include the sum-of-powers computation or - in-place polynomial normalization logic. 
*/ -template -struct alignas(16) _f16VecPN : _f16Vec { - using Base = _f16Vec; - using Converter = typename Base::Converter; - using T1 = typename Base::T1; - using T2 = typename Base::T2; - using Base::data; - - __device__ auto sum_pows() const { - float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f; - -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i + 1]}); - float x2 = z.x * z.x; - float x4 = x2 * x2; - float x6 = x4 * x2; - - float y2 = z.y * z.y; - float y4 = y2 * y2; - float y6 = y4 * y2; - - s2 += x2 + y2; - s4 += x4 + y4; - s6 += x6 + y6; - } - return std::make_tuple(s2, s4, s6); - } - - __device__ void poly_norm_inplace(const float w2_inv_std, - const float w1_inv_std2, - const float w0_inv_std3, const float bias) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i + 1]}); - - float x2 = z.x * z.x; - float x3 = x2 * z.x; - z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias; - - float y2 = z.y * z.y; - float y3 = y2 * z.y; - z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias; - - auto out = Converter::convert(z); - data[i] = out.x; - data[i + 1] = out.y; - } - } -}; - -template -__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> -poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [3] - const scalar_t* __restrict__ bias, // [1] - const float epsilon, const int hidden_size) { - // Sanity checks on our vector struct and type-punned pointer arithmetic - static_assert(std::is_pod_v<_f16VecPN>); - static_assert(sizeof(_f16VecPN) == sizeof(scalar_t) * width); - - /* These and the argument pointers are all declared `restrict` as they are - not aliased in practice. 
Argument pointers should not be dereferenced - in this kernel as that would be undefined behavior */ - auto* __restrict__ input_v = - reinterpret_cast*>(input); - const int vec_hidden_size = hidden_size / width; - float variance = 0.0f; - float variance2 = 0.0f; - float variance3 = 0.0f; - - for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { - int id = blockIdx.x * vec_hidden_size + idx; - _f16VecPN temp = input_v[id]; - auto [x2, x4, x6] = temp.sum_pows(); - - variance += x2; - variance2 += x4; - variance3 += x6; - } - - float3 thread_variances = make_float3(variance, variance2, variance3); - - struct SumOp { - __device__ float3 operator()(const float3& a, const float3& b) const { - return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); - } - }; - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - float3 block_variances = - BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); - - variance = block_variances.x; - variance2 = block_variances.y; - variance3 = block_variances.z; - - __shared__ float s_w2_inv_std; - __shared__ float s_w1_inv_std2; - __shared__ float s_w0_inv_std3; - __shared__ float s_bias; - - if (threadIdx.x == 0) { - float w0 = (float)weight[0]; - float w1 = (float)weight[1]; - float w2 = (float)weight[2]; - s_bias = (float)bias[0]; - - s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); - s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); - s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); - } - __syncthreads(); - - auto* __restrict__ out_v = reinterpret_cast<_f16VecPN*>(out); - - for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { - int id = blockIdx.x * vec_hidden_size + idx; - _f16VecPN temp = input_v[id]; - temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias); - out_v[id] = temp; - } -} - -/* Generic poly_norm_kernel - The width field is not used here but necessary for other specializations. 
- */ -template -__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> -poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [3] - const scalar_t* __restrict__ bias, // [1] - const float epsilon, const int hidden_size) { - float variance = 0.0f; - float variance2 = 0.0f; - float variance3 = 0.0f; - - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * hidden_size + idx]; - float x2 = x * x; - float x4 = x2 * x2; - float x6 = x4 * x2; - - variance += x2; - variance2 += x4; - variance3 += x6; - } - - float3 thread_variances = make_float3(variance, variance2, variance3); - - struct SumOp { - __device__ float3 operator()(const float3& a, const float3& b) const { - return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); - } - }; - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStore; - float3 block_variances = - BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x); - - variance = block_variances.x; - variance2 = block_variances.y; - variance3 = block_variances.z; - - __shared__ float s_w2_inv_std; - __shared__ float s_w1_inv_std2; - __shared__ float s_w0_inv_std3; - __shared__ float s_bias; - - if (threadIdx.x == 0) { - float w0 = (float)weight[0]; - float w1 = (float)weight[1]; - float w2 = (float)weight[2]; - s_bias = (float)bias[0]; - - s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon); - s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon); - s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon); - } - __syncthreads(); - - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * hidden_size + idx]; - float x2 = x * x; - float x3 = x2 * x; - - out[blockIdx.x * hidden_size + idx] = - (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 + - s_bias); - } -} - } // namespace vllm void rms_norm(torch::Tensor& out, // [..., hidden_size] @@ -352,18 +159,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); - int num_tokens = input.numel() / hidden_size; - int64_t input_stride = input.stride(-2); + + // We cannot just use `input.stride(-2)` if the tensor is not row-major. + // Instead, we use a 2d view to get the second-innermost stride. + // That way the dimensions (except the last one) can be arbitrarily permuted. 
+ torch::Tensor input_view = input.view({-1, hidden_size}); + + int num_tokens = input_view.numel() / hidden_size; + int64_t input_stride = input_view.stride(-2); dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), input.data_ptr(), input_stride, - weight.data_ptr(), epsilon, num_tokens, hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input_view.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input_view.data_ptr(), + input_stride, weight.data_ptr(), epsilon, num_tokens, + hidden_size); + }); } #define LAUNCH_FUSED_ADD_RMS_NORM(width) \ @@ -380,6 +195,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(weight.scalar_type() == input.scalar_type()); + TORCH_CHECK(input.scalar_type() == residual.scalar_type()); TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); @@ -414,7 +231,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] wt_ptr % req_alignment_bytes == 0; bool offsets_are_multiple_of_vector_width = hidden_size % vector_width == 0 && input_stride % vector_width == 0; - bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + bool batch_invariant_launch = vllm::vllm_is_batch_invariant(); if (ptrs_are_aligned && offsets_are_multiple_of_vector_width && !batch_invariant_launch) { LAUNCH_FUSED_ADD_RMS_NORM(8); @@ -422,50 +239,3 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] LAUNCH_FUSED_ADD_RMS_NORM(0); } } - -#define LAUNCH_FUSED_POLY_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \ - vllm::poly_norm_kernel<<>>( \ - out.data_ptr(), input.data_ptr(), \ - weight.data_ptr(), bias.data_ptr(), epsilon, \ - hidden_size); \ - }); - -void poly_norm(torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [3] - torch::Tensor& bias, // [1] - double epsilon) { - TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(out.data_ptr() != input.data_ptr()); - - int hidden_size = input.size(-1); - int num_tokens = input.numel() / hidden_size; - - dim3 grid(num_tokens); - /* This kernel is memory-latency bound in many scenarios. - When num_tokens is large, a smaller block size allows - for increased block occupancy on CUs and better latency - hiding on global mem ops. */ - const int max_block_size = (num_tokens < 256) ? 1024 : 256; - dim3 block(std::min(hidden_size, max_block_size)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - /*If the tensor types are FP16/BF16, try to use the optimized kernel - with packed + vectorized ops. - Max optimization is achieved with a width-8 vector of FP16/BF16s - since we can load at most 128 bits at once in a global memory op. - However, this requires each tensor's data to be aligned to 16 - bytes. 
- */ - auto inp_ptr = reinterpret_cast(input.data_ptr()); - auto out_ptr = reinterpret_cast(out.data_ptr()); - bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0; - bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); - if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) { - LAUNCH_FUSED_POLY_NORM(8); - } else { - LAUNCH_FUSED_POLY_NORM(0); - } -} diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 58c3d9c0981a..0f7f034ee180 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -6,10 +6,11 @@ */ #include "type_convert.cuh" -#include "quantization/fp8/common.cuh" +#include "quantization/w8a8/fp8/common.cuh" #include "dispatch_utils.h" #include "cub_helpers.h" #include "core/batch_invariant.hpp" +#include "quantization/vectorization_utils.cuh" #include #include @@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel( __shared__ float s_variance; float variance = 0.0f; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float)input[blockIdx.x * input_stride + idx]; + const scalar_t* input_row = input + blockIdx.x * input_stride; + + constexpr int VEC_SIZE = 8; + auto vec_op = [&variance](const vec_n_t& vec) { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + float x = static_cast(vec.val[i]); + variance += x * x; + } + }; + auto scalar_op = [&variance](const scalar_t& val) { + float x = static_cast(val); variance += x * x; - } + }; + vllm::vectorize_read_with_alignment( + input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; @@ -216,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant( double epsilon) { TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(residual.is_contiguous()); + TORCH_CHECK(residual.scalar_type() == input.scalar_type()); + TORCH_CHECK(weight.scalar_type() == input.scalar_type()); int hidden_size = input.size(-1); int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; @@ -241,7 +256,7 @@ void fused_add_rms_norm_static_fp8_quant( auto wt_ptr = reinterpret_cast(weight.data_ptr()); bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; - bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + bool batch_invariant_launch = vllm::vllm_is_batch_invariant(); if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 && !batch_invariant_launch) { LAUNCH_FUSED_ADD_RMS_NORM(8); diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 629348bf8876..b3d0c0aa58e9 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -8,12 +8,77 @@ #include "../cuda_compat.h" #include "../dispatch_utils.h" +#include "core/math.hpp" #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace vllm { namespace moe { +namespace batched_moe_align_block_size { + +// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. +static constexpr int32_t num_threads = 1024; +static constexpr int32_t num_blocks = 1; +__global__ void batched_moe_align_block_size_kernel( + int32_t const num_batches, int32_t const max_tokens_per_batch, + int32_t const block_size, int32_t const* __restrict__ batch_num_tokens, + int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids, + int32_t* __restrict__ num_tokens_post_pad) { + // TODO(varun): This is a naive implementation. 
Could be optimized. + + size_t const batch_id = threadIdx.x; + size_t const stride = blockDim.x * gridDim.x; + int32_t const num_blocks_per_batch = + CEILDIV(max_tokens_per_batch, block_size); + int32_t const sorted_ids_size = + num_blocks_per_batch * num_batches * block_size; + int32_t const block_ids_size = sorted_ids_size / block_size; + int32_t const SENTINEL = + num_batches * max_tokens_per_batch; // To denote invalid entries. + // Intialize sorted_ids + for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) { + sorted_ids[i] = SENTINEL; + } + // Intialize expert_ids with -1 + for (size_t i = threadIdx.x; i < block_ids_size; i += stride) { + block_ids[i] = -1; + } + + int32_t b_num_tokens = 0; + if (batch_id < num_batches) { + b_num_tokens = batch_num_tokens[batch_id]; + } + int32_t const ceil_b_num_tokens = + CEILDIV(b_num_tokens, block_size) * block_size; + + // Compute prefix sum over token counts per expert + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + int cumsum_val; + BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val); + __syncthreads(); + + bool const is_last_batch = batch_id == (num_batches - 1); + if (is_last_batch) { + *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens; + } + + if (batch_id < num_batches) { + int32_t const batch_offset = batch_id * max_tokens_per_batch; + for (size_t i = 0; i < b_num_tokens; ++i) { + sorted_ids[cumsum_val + i] = batch_offset + i; + } + + int32_t const block_start = cumsum_val / block_size; + int32_t const num_blocks = ceil_b_num_tokens / block_size; + for (size_t i = 0; i < num_blocks; ++i) { + block_ids[block_start + i] = batch_id; + } + } +} +} // namespace batched_moe_align_block_size + template __global__ void moe_align_block_size_kernel( const scalar_t* __restrict__ topk_ids, @@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, }); } +void batched_moe_align_block_size(int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& batch_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor batch_ids, + torch::Tensor num_tokens_post_pad) { + namespace batched_kernel = vllm::moe::batched_moe_align_block_size; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + int32_t const B = batch_num_tokens.size(0); + int32_t const num_blocks_per_batch = + round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; + int32_t const num_blocks = num_blocks_per_batch * B; + int64_t const sorted_ids_size = num_blocks * block_size; + + TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size); + TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size); + TORCH_CHECK(num_tokens_post_pad.size(0) == 1); + TORCH_CHECK(B <= batched_kernel::num_threads); + + batched_kernel::batched_moe_align_block_size_kernel<<< + batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>( + B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr(), + sorted_ids.data_ptr(), batch_ids.data_ptr(), + num_tokens_post_pad.data_ptr()); +} + void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] torch::Tensor& output) // [num_tokens, hidden_size] { diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu new file mode 100644 index 000000000000..e76d1c366785 --- /dev/null +++ b/csrc/moe/moe_lora_align_sum_kernels.cu @@ -0,0 +1,169 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../cuda_compat.h" +#include 
"../dispatch_utils.h" +#include "core/math.hpp" + +namespace { + +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + return row * total_col + col; +} + +} // namespace + +// TODO: Refactor common parts with moe_align_sum_kernels +template +__global__ void moe_lora_align_sum_kernel( + scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping, + int64_t block_size, int num_experts, int max_loras, size_t numel, + int max_num_tokens_padded, int max_num_m_blocks, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int topk_num, int32_t* total_tokens_post_pad) { + const size_t tokens_per_thread = div_ceil(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + int lora_id = blockIdx.x; + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1); + + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { + sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel; + } + + // Initialize expert_ids with -1 + for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) { + expert_ids[lora_id * max_num_m_blocks + it] = -1; + } + + // Initialize total_tokens_post_pad with 0 + if (threadIdx.x == 0) { + total_tokens_post_pad[lora_id] = 0; + } + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int mask = token_lora_mapping[i / topk_num] == lora_id; + int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]); + tokens_cnts[idx] += mask; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; + } + total_tokens_post_pad[lora_id] = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] = + threadIdx.x; + } + } + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. 
+ */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + + int mask = (int)token_lora_mapping[i / topk_num] == lora_id; + atomicAdd( + &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)], + (i - numel) * mask); + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask; + } +} + +void moe_lora_align_block_size(torch::Tensor topk_ids, + torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, + int64_t max_loras, int64_t max_num_tokens_padded, + int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad) { + const int topk_num = topk_ids.size(1); + + TORCH_CHECK(block_size > 0, "block_size should be greater than 0. "); + + int device_max_shared_mem; + auto dev = topk_ids.get_device(); + cudaDeviceGetAttribute(&device_max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE, + TORCH_CHECK(num_thread <= 1024, + "num_thread must be less than 1024, " + "and fallback is not implemented yet."); + const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) + + (num_experts + 1) * sizeof(int32_t); + + if (shared_mem > device_max_shared_mem) { + TORCH_CHECK(false, + "Shared memory usage exceeds device limit, and global memory " + "fallback is not implemented yet."); + } + + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] { + dim3 blockDim(num_thread); + auto kernel = moe_lora_align_sum_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); + kernel<<>>( + topk_ids.data_ptr(), + token_lora_mapping.data_ptr(), block_size, num_experts, + max_loras, topk_ids.numel(), max_num_tokens_padded, + max_num_m_blocks, sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), topk_num, + num_tokens_post_pad.data_ptr()); + }); +} \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 92fc280b362b..e4bf0aa99421 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -4,7 +4,7 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); + torch::Tensor& gating_output, bool renormalize); void moe_sum(torch::Tensor& input, torch::Tensor& output); @@ -12,6 +12,22 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); + +void batched_moe_align_block_size(int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& expert_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad); + +void moe_lora_align_block_size(torch::Tensor topk_ids, + torch::Tensor token_lora_mapping, + int64_t num_experts, int64_t block_size, + int64_t max_loras, int64_t max_num_tokens_padded, + int64_t max_num_m_blocks, + torch::Tensor sorted_token_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index eca021f1c186..af6e6fcd482c 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ 
b/csrc/moe/topk_softmax_kernels.cu @@ -16,12 +16,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include #include "../cuda_compat.h" #include "../cub_helpers.h" -#include "../core/batch_invariant.hpp" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include + typedef __hip_bfloat16 __nv_bfloat16; + typedef __hip_bfloat162 __nv_bfloat162; +#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -37,16 +47,27 @@ template < /// Alignment requirement in bytes int Alignment = sizeof(T) * N > -class alignas(Alignment) AlignedArray { - float data[N]; +struct alignas(Alignment) AlignedArray { + T data[N]; }; +template +__device__ __forceinline__ float toFloat(T value) { + if constexpr (std::is_same_v) { + return value; + } else if constexpr (std::is_same_v) { + return __bfloat162float(value); + } else if constexpr (std::is_same_v) { + return __half2float(value); + } +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. -template +template __launch_bounds__(TPB) __global__ - void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) + void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -67,7 +88,8 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - threadData = max(static_cast(input[idx]), threadData); + const float val = toFloat(input[idx]); + threadData = max(val, threadData); } const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp()); @@ -82,7 +104,8 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - threadData += exp((static_cast(input[idx]) - float_max)); + const float val = toFloat(input[idx]); + threadData += expf(val - float_max); } const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp()); @@ -96,8 +119,9 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; - output[idx] = val; + const float val = toFloat(input[idx]); + const float softmax_val = expf(val - float_max) * normalizing_factor; + output[idx] = softmax_val; } } @@ -111,7 +135,8 @@ __launch_bounds__(TPB) __global__ void moeTopK( const int num_experts, const int k, const int start_expert, - const int end_expert) + const int end_expert, + const bool renormalize) { using cub_kvp = cub::KeyValuePair; @@ -126,6 +151,7 @@ __launch_bounds__(TPB) __global__ void moeTopK( const bool row_is_active = finished ? !finished[block_row] : true; const int thread_read_offset = blockIdx.x * num_experts; + float selected_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; @@ -164,9 +190,23 @@ __launch_bounds__(TPB) __global__ void moeTopK( indices[idx] = should_process_row ? 
(expert - start_expert) : num_experts; assert(indices[idx] >= 0); source_rows[idx] = k_idx * num_rows + block_row; + if (renormalize) { + selected_sum += result_kvp.value; + } } __syncthreads(); } + + // Renormalize the k weights for this row to sum to 1, if requested. + if (renormalize) { + if (threadIdx.x == 0) { + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * block_row + k_idx; + output[idx] = output[idx] / denom; + } + } + } } // ====================== TopK softmax things =============================== @@ -185,21 +225,30 @@ __launch_bounds__(TPB) __global__ void moeTopK( 2) This implementation assumes k is small, but will work for any k. */ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, - int* source_rows, const int k, const int start_expert, const int end_expert) + void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices, + int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize) { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "InputType must be float, __nv_bfloat16, or __half"); + // We begin by enforcing compile time assertions and setting up compile time constants. static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); // Number of bytes each thread pulls in per load - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); static constexpr int ELTS_PER_ROW = NUM_EXPERTS; static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + if constexpr (std::is_same_v || std::is_same_v) { + static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0, + "ELTS_PER_LDG must be 1 or even for 16-bit conversion"); + } + // Restrictions based on previous section. static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); @@ -237,27 +286,71 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the // row it will read. - const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW; // Now, we compute the group each thread belong to in order to determine the first column to start loads. const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; - const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; - - // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, - // this can support all powers of 2 up to 16. - // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. - // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. 
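For reference, a minimal host-side sketch of what the new `renormalize` flag is meant to compute for one token row: softmax over all experts, select the top-k, and, when requested, rescale the k selected weights so they sum to 1, which is the same `selected_sum` / `denom` logic added to `moeTopK` and `topkGatingSoftmax` in this patch. This is an illustration only (the helper name and container choices are invented for the sketch), not the fused CUDA path.

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Reference semantics for one row of gating logits; assumes 1 <= k <= num_experts.
static void topk_softmax_ref(const std::vector<float>& logits, int k,
                             bool renormalize, std::vector<float>& weights,
                             std::vector<int>& indices) {
  // Max-subtracted softmax over all experts, as in moeSoftmax.
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> prob(logits.size());
  float z = 0.f;
  for (size_t i = 0; i < logits.size(); ++i) {
    prob[i] = std::exp(logits[i] - max_logit);
    z += prob[i];
  }
  for (float& p : prob) p /= z;

  // Pick the k largest probabilities.
  std::vector<int> order(prob.size());
  std::iota(order.begin(), order.end(), 0);
  std::partial_sort(order.begin(), order.begin() + k, order.end(),
                    [&](int a, int b) { return prob[a] > prob[b]; });

  weights.assign(k, 0.f);
  indices.assign(k, 0);
  float selected_sum = 0.f;
  for (int i = 0; i < k; ++i) {
    indices[i] = order[i];
    weights[i] = prob[order[i]];
    selected_sum += weights[i];
  }
  // renormalize=true rescales only the selected weights so they sum to 1.
  if (renormalize && selected_sum > 0.f) {
    for (float& w : weights) w /= selected_sum;
  }
}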
- using AccessType = AlignedArray; + const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; // Finally, we pull in the data from global mem float row_chunk[VPT]; - AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); - const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); + + // NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float + if constexpr (std::is_same_v) { + using VecType = AlignedArray; + VecType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); #pragma unroll - for (int ii = 0; ii < LDG_PER_THREAD; ++ii) - { - row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + } else if constexpr (std::is_same_v) { + if constexpr (ELTS_PER_LDG >= 2) { + using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>; + float2* row_chunk_f2 = reinterpret_cast(row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + int base_idx_f2 = ii * ELTS_PER_LDG / 2; +#pragma unroll + for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) { + row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2( + *reinterpret_cast(vec.data + jj * 2) + ); + } + } + } else { // ELTS_PER_LDG == 1 +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW; + row_chunk[ii] = __bfloat162float(*scalar_ptr); + } + } + } else if constexpr (std::is_same_v) { + if constexpr (ELTS_PER_LDG >= 2) { + using VecType = AlignedArray<__half, ELTS_PER_LDG>; + float2* row_chunk_f2 = reinterpret_cast(row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + int base_idx_f2 = ii * ELTS_PER_LDG / 2; +#pragma unroll + for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) { + row_chunk_f2[base_idx_f2 + jj] = __half22float2( + *reinterpret_cast(vec.data + jj * 2) + ); + } + } + } else { // ELTS_PER_LDG == 1 +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW; + row_chunk[ii] = __half2float(*scalar_ptr); + } + } } // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just @@ -311,6 +404,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ int start_col = first_elt_read_by_thread; static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + float selected_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { // First, each thread does the local argmax @@ -364,6 +458,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ output[idx] = max_val; indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; source_rows[idx] = k_idx * num_rows + thread_row; + if (renormalize) { + selected_sum += max_val; + } } // Finally, we clear the value in the thread with the current max if there is another iteration to run. @@ -381,15 +478,28 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ } } } + + // Renormalize the k weights for this row to sum to 1, if requested. 
+ if (renormalize) { + if (thread_group_idx == 0) + { + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] / denom; + } + } + } } namespace detail { // Constructs some constants needed to partition the work across threads at compile time. -template +template struct TopkConstants { - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, ""); static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; @@ -398,21 +508,21 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, - int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +template +void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices, + int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize, + cudaStream_t stream) { - static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); - using Constants = detail::TopkConstants; + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS); + using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; - const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); - const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); - topkGatingSoftmax<<>>( - input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize); } #ifndef USE_ROCM @@ -420,26 +530,26 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f static_assert(WARP_SIZE == 32, \ "Unsupported warp size. 
Only 32 is supported for CUDA"); \ topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); #else #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ if (WARP_SIZE == 64) { \ topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); \ } else if (WARP_SIZE == 32) { \ topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); \ } else { \ assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \ } #endif -template +template void topkGatingSoftmaxKernelLauncher( - const float* gating_output, + const InputType* gating_output, float* topk_weights, IndType* topk_indices, int* token_expert_indices, @@ -447,11 +557,15 @@ void topkGatingSoftmaxKernelLauncher( const int num_tokens, const int num_experts, const int topk, + const bool renormalize, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; #ifndef USE_ROCM - static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8; + // for bfloat16 dtype, we need 4 bytes loading to make sure num_experts + // elements can be loaded by a warp + static constexpr int BYTES_PER_LDG_MULTIPLE_64 = + (std::is_same_v || std::is_same_v) ? 
4 : 8; #endif switch (num_experts) { case 1: @@ -508,11 +622,11 @@ void topkGatingSoftmaxKernelLauncher( TORCH_CHECK(softmax_workspace != nullptr, "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64."); static constexpr int TPB = 256; - moeSoftmax<<>>( + moeSoftmax<<>>( gating_output, nullptr, softmax_workspace, num_experts); moeTopK<<>>( softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, - num_experts, topk, 0, num_experts); + num_experts, topk, 0, num_experts, renormalize); } } } @@ -520,11 +634,50 @@ void topkGatingSoftmaxKernelLauncher( } // namespace moe } // namespace vllm + +template +void dispatch_topk_softmax_launch( + torch::Tensor& gating_output, + torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& softmax_workspace, + int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream) +{ + if (topk_indices.scalar_type() == at::ScalarType::Int) { + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } else if (topk_indices.scalar_type() == at::ScalarType::UInt32) { + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } else { + TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } +} + void topk_softmax( torch::Tensor& topk_weights, // [num_tokens, topk] torch::Tensor& topk_indices, // [num_tokens, topk] torch::Tensor& token_expert_indices, // [num_tokens, topk] - torch::Tensor& gating_output) // [num_tokens, num_experts] + torch::Tensor& gating_output, // [num_tokens, num_experts] + bool renormalize) { const int num_experts = gating_output.size(-1); const auto num_tokens = gating_output.numel() / num_experts; @@ -536,45 +689,19 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); - - if(topk_indices.scalar_type() == at::ScalarType::Int) - { - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); - } - else if (topk_indices.scalar_type() == at::ScalarType::UInt32) - { - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); - } - else { - TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - 
softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float); + torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options); + + if (gating_output.scalar_type() == at::ScalarType::Float) { + dispatch_topk_softmax_launch(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else if (gating_output.scalar_type() == at::ScalarType::Half) { + dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else if (gating_output.scalar_type() == at::ScalarType::BFloat16) { + dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type()); } } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 8f33d6cd666f..c08a543908ef 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " - "token_expert_indices, Tensor gating_output) -> ()"); + "token_expert_indices, Tensor gating_output, bool renormalize) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); // Calculate the result of moe by summing up the partial results @@ -22,6 +22,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor! num_tokens_post_pad) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size, but for the batched case. + m.def( + "batched_moe_align_block_size(int max_tokens_per_batch," + " int block_size, Tensor expert_num_tokens," + " Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + m.impl("batched_moe_align_block_size", torch::kCUDA, + &batched_moe_align_block_size); + + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size. + m.def( + "moe_lora_align_block_size(Tensor topk_ids," + " Tensor token_lora_mapping," + " int num_experts," + " int block_size, int max_loras, " + " int max_num_tokens_padded, " + " int max_num_m_blocks, " + " Tensor !sorted_token_ids," + " Tensor !experts_ids," + " Tensor !num_tokens_post_pad) -> () "); + m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size); + #ifndef USE_ROCM m.def( "moe_wna16_gemm(Tensor input, Tensor! 
output, Tensor b_qweight, " diff --git a/csrc/ops.h b/csrc/ops.h index 9dd302faf5b8..0bed7492f661 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -92,9 +92,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); -void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - torch::Tensor& bias, double epsilon); - void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, @@ -102,8 +99,11 @@ void apply_repetition_penalties_(torch::Tensor& logits, void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, const torch::Tensor& rowEnds, torch::Tensor& indices, - torch::Tensor& values, int64_t numRows, int64_t stride0, - int64_t stride1); + int64_t numRows, int64_t stride0, int64_t stride1); + +void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, + const torch::Tensor& seq_lens, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1); void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, @@ -138,12 +138,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& input_global_scale); #endif -void silu_mul_fp8_quant_deep_gemm_cuda( +void persistent_masked_m_silu_mul_quant( const at::Tensor& input, // (E, T, 2*H) const at::Tensor& counts, // (E) at::Tensor& y_q, // (E, T, H) [OUT] at::Tensor& y_s, // (E, T, H//group_size) [OUT] - int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens); + bool use_ue8m0); void mul_and_silu(torch::Tensor& out, torch::Tensor& input); @@ -307,7 +307,7 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit); + bool use_exllama, bool use_v2_format, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index b94cc9ce5086..6fcd246f63c5 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -7,7 +7,7 @@ #include "../cuda_compat.h" #include "dispatch_utils.h" -#include "quantization/fp8/common.cuh" +#include "quantization/w8a8/fp8/common.cuh" #include @@ -114,13 +114,22 @@ __global__ void act_and_mul_quant_kernel( } __device__ __forceinline__ float silu(float x) { - return (__fdividef(x, (1.f + expf(-x)))); + return __fdividef(x, (1.f + expf(-x))); } __device__ __forceinline__ float2 silu2(float2 x) { return make_float2(silu(x.x), silu(x.y)); } +__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) { +#ifndef USE_ROCM + return make_bfloat162(__float2bfloat16_rn(silu(x.x)), + __float2bfloat16_rn(silu(x.y))); +#else + return __float22bfloat162_rn(make_float2(silu(x.x), silu(x.y))); +#endif +} + #ifndef USE_ROCM __device__ __forceinline__ float warp_max(float v) { static constexpr unsigned FULL_MASK = 0xffffffffu; @@ -223,224 +232,308 @@ constexpr __nv_bfloat16 get_fp8_min() { return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032}); } } -#ifndef USE_ROCM -template +__device__ __forceinline__ int warp_expert_search( + int idx, int n, const Idx_t* __restrict__ input, Idx_t val) { + const Idx_t* input_ptr = input + 
idx; + int base_offset = 0; + + for (;;) { + bool move_on = (idx < n && *input_ptr <= val); + + unsigned mask = __ballot_sync(0xffffffff, move_on); + + if (mask != 0xffffffffu) { + int last_lane = 31 - __clz(mask); + return base_offset + last_lane; + } + + input_ptr += 32; + base_offset += 32; + idx += 32; + } +} + +template +__device__ __forceinline__ void token_bounds(int32_t n_tokens, + int32_t worker_id, + int32_t& n_tokens_lower, + int32_t& n_tokens_upper) { + if (n_tokens < num_parallel_tokens && worker_id < n_tokens) { + if (worker_id >= num_parallel_tokens) return; + n_tokens_lower = worker_id; + n_tokens_upper = worker_id + 1; + } else { + int32_t chunk_size = n_tokens / num_parallel_tokens; + int32_t residual = n_tokens - chunk_size * num_parallel_tokens; + auto calc_id = [&](int32_t id) { + if (id < residual) + return min(n_tokens, id * (chunk_size + 1)); + else + return min(n_tokens, id * chunk_size + residual); + }; + n_tokens_lower = calc_id(worker_id); + n_tokens_upper = calc_id(worker_id + 1); + } +} + +template __global__ void silu_mul_fp8_quant_deep_gemm_kernel( const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q, - float* __restrict__ _y_s, const int32_t* __restrict__ counts, - + float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert, // sizes - int H, int G, - + Idx_t E, Idx_t T, Idx_t H, // strides (in elements) Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, Idx_t stride_ys_g, Idx_t stride_counts_e) { - static constexpr __nv_bfloat16 fp8_min = get_fp8_min(); - static constexpr __nv_bfloat16 fp8_max = get_fp8_max(); - // We assign EPS with its 16-bit unsigned counterpart to allow constexpr. - static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996}); +#ifndef USE_ROCM + static constexpr int NUM_WARPS = THREADS / WARP_SIZE; + + static constexpr int LOAD_STAGE_SIZE = 2 * GROUP_SIZE / 8; + static constexpr int LOAD_STAGE_MOD = NUM_STAGES * LOAD_STAGE_SIZE; - // We pack 8 16-bit bfloat16 values into a 128-bit __int128_t. - static constexpr int32_t BFLOAT16_PER_GROUP = 8; + static constexpr int COMPUTE_STAGE_SIZE = 2 * GROUP_SIZE / 4; + static constexpr int COMPUTE_STAGE_MOD = COMPUTE_STAGE_SIZE * NUM_STAGES; - // We split the shared memory in half, corresponding to gate and up matrices: - // [...gate_i, ...up_i] where 0 <= i < stages. - static constexpr int32_t S_NUM_128 = - 2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES; - static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE; - static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2; - static constexpr int32_t S_NUM_64 = S_NUM_128 * 2; - __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128]; + extern __shared__ __align__(16) __int128_t smem_128[]; - const int32_t tid = threadIdx.x; - const int32_t warp_id = tid / WARP_SIZE; - const int32_t lane_id = tid % WARP_SIZE; + int* s_expert_offsets = + reinterpret_cast(smem_128 + (SMEM_SIZE_BYTES_Y / 16)); - auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128); + static constexpr __nv_bfloat16 fp8_min = get_fp8_min(); + static constexpr __nv_bfloat16 fp8_max = get_fp8_max(); + // We assign EPS with it's 16-bit unsigned counterpart to allow constexpr. 
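A scalar view of `warp_expert_search` may help when reading the code above: given the exclusive prefix offsets stored in `s_expert_offsets` (offsets[e] is the first flattened token id owned by expert e), it returns the largest e with offsets[e] <= token_id, i.e. the expert that owns that token, scanning 32 entries per `__ballot_sync` instead of one at a time. The reference helper and example values below are illustrative and not part of the kernel.

#include <vector>

// offsets has num_experts + 1 entries: offsets[0] == 0 and
// offsets[e + 1] - offsets[e] is the token count of expert e.
// Assumes token_id < offsets.back(), i.e. a valid token.
static int find_expert_ref(const std::vector<int>& offsets, int token_id) {
  const int num_experts = static_cast<int>(offsets.size()) - 1;
  int e = 0;
  while (e + 1 < num_experts && offsets[e + 1] <= token_id) ++e;
  return e;
}

// Example: offsets = {0, 3, 3, 7}, so expert 1 received zero tokens.
//   find_expert_ref(offsets, 2) == 0 and find_expert_ref(offsets, 3) == 2,
// matching how the advance loop further below skips empty experts.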
+ static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996}); + int tid = threadIdx.x; + int warp_id = tid >> 5; + int lane_id = tid & 0x1f; + + int running_sum{}; + if (!warp_id) { + for (int i = 0; i < E; i += WARP_SIZE) { + bool valid = (i + threadIdx.x) < E; + int value = + (valid ? tokens_per_expert[i + threadIdx.x * stride_counts_e] : 0) + + (!lane_id ? running_sum : 0); + + for (int offset = 1; offset < 32; offset *= 2) { + int n = __shfl_up_sync(0xFFFFFFFFu, value, offset); + if (lane_id >= offset) value += n; + } - // block handles one (expert e, group g) - int32_t pid = blockIdx.x; - int32_t e = pid / G; - int32_t g = pid % G; + if (valid) { + s_expert_offsets[i + threadIdx.x + 1] = value; + } - const int32_t n_tokens = counts[e * stride_counts_e]; + running_sum = __shfl_sync(0xFFFFFFFFu, value, WARP_SIZE - 1); + } - if (!n_tokens) { - return; // Exit ASAP. + if (!lane_id) { + s_expert_offsets[0] = 0; + } } - const Idx_t stride_i_t_128 = stride_i_t / 8u; + __syncthreads(); + + int32_t total_tokens = s_expert_offsets[E]; - int32_t n_tokens_lower, n_tokens_upper; + const int warp_position_yq = warp_id * (H / NUM_WARPS); + const int warp_position_scales = warp_id * (H / (GROUP_SIZE * NUM_WARPS)); + // A single block will handle tokens_per_block tokens. // Each block i iterates over tokens of a slice of n_tokens = // expert_counts[i], with the size of chunk being // (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of // updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling. - if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) { - // Specialize this, but can be likely fused. - if (blockIdx.y >= NUM_PARALLEL_TOKENS) { - return; - } - n_tokens_lower = blockIdx.y; - n_tokens_upper = blockIdx.y + 1; - } else { - auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS; - auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS; - auto calc_id = [&](int32_t id) { - if (id < residual) { - return min(n_tokens, id * (chunk_size + 1)); - } else { - return min(n_tokens, id * chunk_size + residual); - } - }; - n_tokens_lower = calc_id(blockIdx.y); - n_tokens_upper = calc_id(blockIdx.y + 1); - } - if (n_tokens_lower >= n_tokens_upper) { + // Each warp will get space to store its hidden dim for gate and up. + __int128_t* s_hidden_load = smem_128 + warp_id * ((2 * 128 / 8) * NUM_STAGES); + __int128_t* smem_load_ptr = s_hidden_load + lane_id; + + const __nv_bfloat16 fp8_inv = __hdiv(__float2bfloat16(1.f), fp8_max); + + int32_t compute_pipeline_offset_64 = 0; + int32_t load_stage_offset{}; + const __nv_bfloat16 one_bf16 = __float2bfloat16_rn(1.f); + + __int64_t* smem_compute_ptr = reinterpret_cast<__int64_t*>(smem_128) + + warp_id * (2 * (GROUP_SIZE / 4) * NUM_STAGES) + + lane_id; + __int64_t* s_gate64_ptr = smem_compute_ptr; + __int64_t* s_up64_ptr = smem_compute_ptr + GROUP_SIZE / 4; + + int tokens_lower, tokens_upper; + + token_bounds(total_tokens, blockIdx.x, tokens_lower, + tokens_upper); + + Idx_t expert_id{}, expert_offset{}, next_expert_offset{}; + int token_id = tokens_lower; + int32_t t_load{}; + + if (token_id < tokens_upper) { + expert_id = warp_expert_search(lane_id, E, s_expert_offsets, token_id); + expert_offset = s_expert_offsets[expert_id]; + next_expert_offset = s_expert_offsets[expert_id + 1]; + } else { + // This thread block has no work to do. return; } - // We do calculations here, using constexpr wherever possible. 
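The `s_expert_offsets` construction above is a Hillis-Steele warp scan carried across 32-expert chunks through `running_sum`. Below is a stripped-down sketch of the same pattern, assuming a single-warp launch and at most 32 counts (the kernel's chunked loop removes that limit); the names are invented for the sketch.

__device__ int warp_inclusive_scan(int value) {
  const int lane = threadIdx.x & 31;
#pragma unroll
  for (int offset = 1; offset < 32; offset *= 2) {
    int n = __shfl_up_sync(0xffffffffu, value, offset);
    if (lane >= offset) value += n;  // lanes below `offset` keep their partial sum
  }
  return value;  // lane i now holds counts[0] + ... + counts[i]
}

// Launch as exclusive_offsets_kernel<<<1, 32>>>(counts, offsets, n) with n <= 32;
// offsets must have n + 1 entries.
__global__ void exclusive_offsets_kernel(const int* __restrict__ counts,
                                         int* __restrict__ offsets, int n) {
  const int lane = threadIdx.x;
  const int v = (lane < n) ? counts[lane] : 0;
  const int inclusive = warp_inclusive_scan(v);
  if (lane < n) offsets[lane + 1] = inclusive;  // shift by one => exclusive offsets
  if (lane == 0) offsets[0] = 0;
}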
- const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h; - const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g; - const Idx_t base_yq = - e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h; - Idx_t gate_off_128 = (base_i / static_cast(8u)); - auto input_128_ptr = reinterpret_cast(_input); - auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) + - stride_i_t_128 * n_tokens_lower; - auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u; - auto y_s_ptr = - _y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t; - auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE + - stride_yq_t * n_tokens_lower + 4 * lane_id; - int32_t t_load = n_tokens_lower, load_stage_id = 0; - auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT); - auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u; - int32_t stage_offset{}; - - static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2); - static constexpr int32_t LOAD_STAGE_MOD = - NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2); - - // Two halves of all threads in a block conduct global loads for gate and up, - // repsectively. + int t_load_bound = H / (GROUP_SIZE * NUM_WARPS); + + Idx_t base_i = ((expert_id * stride_i_e) / 8) + + (token_id - expert_offset) * stride_i_t / 8; + const Idx_t gate_warp_offset = + warp_id * ((stride_i_h * H) / (8 * NUM_WARPS)) + (lane_id & 0b1111); + + const __int128_t* input_128_ptr = + reinterpret_cast(_input) + gate_warp_offset + + ((lane_id < 16) ? 0 : ((H * stride_i_h) / 8)); + __int128_t* load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); + + auto token_offset = token_id - expert_offset; + auto load_and_advance_y_pred = [&] { - if (t_load < n_tokens_upper) { - auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset; - auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset; + if (t_load < t_load_bound) { + // Here we are simply continuing to load data + // from the current token. + auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset; // It is very important that LOAD_STAGE_SIZE is constexpr to avoid // unnecessary ALU ops. - stage_offset += LOAD_STAGE_SIZE; - stage_offset %= LOAD_STAGE_MOD; + load_stage_offset += LOAD_STAGE_SIZE; + load_stage_offset %= LOAD_STAGE_MOD; - if (tid < HALF_THREAD_COUNT) { - cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr); - gate_128_ptr += stride_i_t_128; + cp_async4(smem_load_ptr_staged, load_ptr); + load_ptr += GROUP_SIZE / 8; + ++t_load; + } else if (token_id + 1 < tokens_upper) { + // We loaded everything from the current token, let's move on + // to the next one, and we checked that we have more tokens to load. + ++token_id; + t_load = 0; + if (token_id >= next_expert_offset) { + // We need to find the next expert. + do { + // This is a loop because it's possible + // that some experts are assigned 0 tokens. + // NOTE: We are guaranteed that there's at least + // one more token left so we don't have to check for + // expert_id bounds. + ++expert_id; + // This skips 1 memory read. + expert_offset = next_expert_offset; + next_expert_offset = s_expert_offsets[expert_id + 1]; + } while (next_expert_offset == expert_offset); + + base_i = expert_id * (stride_i_e / 8); + token_offset = 0; + load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); } else { - cp_async4(s_up_stage_128_staged_ptr, up_128_ptr); - up_128_ptr += stride_i_t_128; + // We remain within the same expert, so just + // move by H/4 __int128_t (2 * H/8). 
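`load_and_advance_y_pred` above is a NUM_STAGES-deep cp.async software pipeline: NUM_STAGES - 1 copies are put in flight before the compute loop, and each iteration waits for the oldest copy, issues the next one, then computes on resident data. The standalone kernel below sketches the same scheduling shape with the stock <cuda_pipeline.h> primitives instead of vLLM's cp_async4/cp_async_fence/cp_async_wait wrappers; it is an illustration, not the production kernel.

#include <cuda_pipeline.h>

template <int STAGES, int TILE>
__global__ void pipelined_sum(const float4* __restrict__ in,
                              float* __restrict__ out, int num_tiles) {
  __shared__ float4 buf[STAGES][TILE];
  const int lane = threadIdx.x;  // blockDim.x == TILE assumed in this sketch

  auto issue = [&](int tile) {
    if (tile < num_tiles) {
      // Tile m always lands in stage m % STAGES.
      __pipeline_memcpy_async(&buf[tile % STAGES][lane],
                              &in[tile * TILE + lane], sizeof(float4));
    }
    __pipeline_commit();  // commit even when nothing was issued, as above
  };

  // Warm-up: put STAGES - 1 copies in flight.
  for (int t = 0; t < STAGES - 1; ++t) issue(t);

  float acc = 0.f;
  for (int t = 0; t < num_tiles; ++t) {
    __pipeline_wait_prior(STAGES - 2);  // the copy for tile t is now resident
    __syncthreads();
    issue(t + STAGES - 1);              // keep the pipeline full
    const float4 v = buf[t % STAGES][lane];
    acc += v.x + v.y + v.z + v.w;
  }
  atomicAdd(out, acc);
}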
+ base_i += stride_yq_t / 4; + token_offset++; } + + load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); + + auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset; + + // It is very important that LOAD_STAGE_SIZE is constexpr to avoid + // unnecessary ALU ops. + load_stage_offset += LOAD_STAGE_SIZE; + load_stage_offset %= LOAD_STAGE_MOD; + + cp_async4(smem_load_ptr_staged, load_ptr); + load_ptr += GROUP_SIZE / 8; ++t_load; - ++load_stage_id; } // We fence even if there is nothing to load to simplify pipelining. cp_async_fence(); }; + // We need to warm-up the pipeline. #pragma unroll for (int i = 0; i < NUM_STAGES - 1; i++) { load_and_advance_y_pred(); } - __int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>( - s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) + - lane_id; - __int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2; - - static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u; - static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES; - - int32_t compute_pipeline_offset_64 = 0; - - for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) { - __nv_bfloat162 results_bf162[2]; + __nv_fp8x4_e4m3* y_q_base_ptr = + reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id; + auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g; - cp_async_wait(); - __syncthreads(); + for (auto j = tokens_lower; j < tokens_upper; j++) { + const Idx_t base_ys = expert_id * stride_ys_e; + auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t; + __nv_fp8x4_e4m3* y_q_ptr = + y_q_base_ptr + (expert_id * stride_yq_e + token_offset * stride_yq_t + + warp_position_yq * stride_yq_h) / + 4; + const int COMPUTE_LIMIT = H / (GROUP_SIZE * NUM_WARPS); - // We double-buffer pipelined loads so that the next load will - // concurrently run with compute without overwrites. - load_and_advance_y_pred(); + for (int i = 0; i < COMPUTE_LIMIT; i++) { + cp_async_wait(); + __syncthreads(); + load_and_advance_y_pred(); - auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64; - auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64; + __int64_t* gate64_ptr = s_gate64_ptr + compute_pipeline_offset_64; + __int64_t* up64_ptr = s_up64_ptr + compute_pipeline_offset_64; - // STAGE_SIZE must also be constexpr! - compute_pipeline_offset_64 += STAGE_SIZE; - compute_pipeline_offset_64 %= STAGE_MOD; + // COMPUTE_STAGE_SIZE/MOD must also be constexpr! + compute_pipeline_offset_64 += COMPUTE_STAGE_SIZE; + compute_pipeline_offset_64 %= COMPUTE_STAGE_MOD; - // Each thread loads (gate/up) 2X 4X bfloat16 values into registers. - __int64_t gate64 = *s_gate_compute_64; - __nv_bfloat162* s_gate_compute_32 = - reinterpret_cast<__nv_bfloat162*>(&gate64); + __int64_t gate64 = *gate64_ptr; + __int64_t up64 = *up64_ptr; - __int64_t up64 = *s_up_compute_64; - __nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64); + // Compute + __nv_bfloat162 res[2]; + __nv_bfloat162* s_up_comp = reinterpret_cast<__nv_bfloat162*>(&up64); + __nv_bfloat162* s_gate_comp = reinterpret_cast<__nv_bfloat162*>(&gate64); #pragma unroll - for (int i = 0; i < 2; i++) { - // For silu, we make sure that div is emitted. 
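In plain float math, the scale-and-clip sequence in the compute loop amounts to the following per 128-element group. This is a reference sketch only: the kernel stays in bf16, reduces the group maximum with `warp_max`, and takes EPS and the fp8 limits from `get_fp8_min`/`get_fp8_max`, so the constants below (448.0 for e4m3fn, 1e-10 for the epsilon) are merely illustrative.

#include <algorithm>
#include <cmath>
#include <vector>

// One group of `gate` and `up` values (GROUP_SIZE == 128 in the kernel).
static void silu_mul_quant_group_ref(const std::vector<float>& gate,
                                     const std::vector<float>& up,
                                     bool use_ue8m0,
                                     std::vector<float>& y_q, float& y_s) {
  const float FP8_MAX = 448.f;  // finite max of e4m3fn (illustrative)
  const float EPS = 1e-10f;     // keeps the scale strictly positive
  std::vector<float> y(gate.size());
  float group_max = EPS;
  for (size_t i = 0; i < gate.size(); ++i) {
    const float s = gate[i] / (1.f + std::exp(-gate[i]));  // silu(gate)
    y[i] = s * up[i];
    group_max = std::max(group_max, std::fabs(y[i]));
  }
  y_s = group_max / FP8_MAX;
  if (use_ue8m0) {
    y_s = std::exp2(std::ceil(std::log2(y_s)));  // round scale up to a power of two
  }
  y_q.resize(y.size());
  for (size_t i = 0; i < y.size(); ++i) {
    // The kernel multiplies by 1/y_s and then casts the clipped value to fp8.
    y_q[i] = std::clamp(y[i] / y_s, -FP8_MAX, FP8_MAX);
  }
}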
- float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i])); - results_bf162[i] = __float22bfloat162_rn(gate); - } - - #pragma unroll - for (int i = 0; i < 2; i++) { - results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]); - } + for (int32_t k = 0; k < 2; ++k) { + __nv_bfloat162 gate = silu2_v2(__bfloat1622float2(s_gate_comp[k])); + res[k] = __hmul2(gate, s_up_comp[k]); + } - auto _y_max2 = - __hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1])); + auto _y_max2 = __hmax2(__habs2(res[0]), __habs2(res[1])); - __nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y)); + _y_max2.x = __hmax(__hmax(_y_max2.x, _y_max2.y), EPS); - // An entire group is assigned to a single warp, so a simple warp reduce - // is used. - __nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max; + __nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv); - if constexpr (USE_UE8M0) { - y_s = hexp2(hceil(hlog2(y_s))); - } + if constexpr (USE_UE8M0) { + y_s = hexp2(hceil(hlog2(y_s))); + } - auto inv_y = __float2bfloat16_rn(1.f) / y_s; + __nv_bfloat16 inv_y = __hdiv(one_bf16, y_s); - auto y_s2 = make_bfloat162(inv_y, inv_y); + auto y_s2 = make_bfloat162(inv_y, inv_y); #pragma unroll - for (int32_t i = 0; i < 2; ++i) { - results_bf162[i] = - clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min), - __bfloat162bfloat162(fp8_max)); - } + for (int32_t k = 0; k < 2; ++k) { + res[k] = clip(__hmul2(res[k], y_s2), __bfloat162bfloat162(fp8_min), + __bfloat162bfloat162(fp8_max)); + } - auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]); - *reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4; - y_q_ptr += stride_yq_t; + *y_q_ptr = __nv_fp8x4_e4m3(res[0], res[1]); + y_q_ptr += WARP_SIZE * stride_yq_h; - if (lane_id == 0) { - *y_s_ptr = y_s; - y_s_ptr += stride_ys_t; + if (!lane_id) { + *y_s_ptr = y_s; + y_s_ptr += stride_ys_g; + } } } -} #endif +} } // namespace vllm @@ -475,14 +568,14 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } -void silu_mul_fp8_quant_deep_gemm_cuda( - const at::Tensor& input, // (E, T, 2*H) - const at::Tensor& counts, // (E) - at::Tensor& y_q, // (E, T, H) [OUT] - at::Tensor& y_s, // (E, T, H//group_size) [OUT] - int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) { +void persistent_masked_m_silu_mul_quant( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& tokens_per_expert, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + bool use_ue8m0) { #ifndef USE_ROCM - // This kernel relies heavily on cp.async and fp8 support. + // This kernel currently only supports H % 128 == 0 and assumes a // fixed GROUP_SIZE of 128. TORCH_CHECK(input.dtype() == torch::kBFloat16); @@ -491,10 +584,6 @@ void silu_mul_fp8_quant_deep_gemm_cuda( TORCH_CHECK(y_s.dtype() == torch::kFloat32); TORCH_CHECK(input.size(-1) % 256 == 0); - // Check that num_parallel_tokens is of power of 2 and between 1 and 64. 
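Taken together, the checks above define the op's shape contract: bf16 activations of shape (E, T, 2*H) with gate in the first H channels and up in the second, int32 per-expert token counts of shape (E), an fp8 output of shape (E, T, H), and float32 group scales of shape (E, T, H/128) given the fixed GROUP_SIZE of 128. The libtorch-side sketch below is hypothetical and only illustrates the shapes; in vLLM the op is normally reached through the registered torch bindings, and the direct call resolves only when linked against the extension.

#include <torch/torch.h>

// Declaration as in csrc/ops.h (see the header diff above).
void persistent_masked_m_silu_mul_quant(const at::Tensor& input,
                                        const at::Tensor& counts,
                                        at::Tensor& y_q, at::Tensor& y_s,
                                        bool use_ue8m0);

void example_shapes() {
  const int64_t E = 8, T = 64, H = 4096;  // H % 128 == 0, so input.size(-1) % 256 == 0
  const auto opts = torch::TensorOptions().device(torch::kCUDA);

  auto input = torch::randn({E, T, 2 * H}, opts.dtype(torch::kBFloat16));
  auto counts = torch::full({E}, T, opts.dtype(torch::kInt32));  // tokens per expert
  auto y_q = torch::empty({E, T, H}, opts.dtype(at::ScalarType::Float8_e4m3fn));
  auto y_s = torch::empty({E, T, H / 128}, opts.dtype(torch::kFloat32));

  persistent_masked_m_silu_mul_quant(input, counts, y_q, y_s,
                                     /*use_ue8m0=*/false);
}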
- TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64); - TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1))); - using Idx_t = int64_t; Idx_t E = input.size(0); @@ -510,81 +599,54 @@ void silu_mul_fp8_quant_deep_gemm_cuda( Idx_t stride_ys_t = y_s.stride(1); Idx_t stride_ys_g = y_s.stride(2); - Idx_t stride_counts_e = counts.stride(0); + Idx_t stride_counts_e = tokens_per_expert.stride(0); static constexpr int GROUP_SIZE = 128; - #define KERNEL_FN \ - if (use_ue8m0) { \ - vllm::silu_mul_fp8_quant_deep_gemm_kernel \ - <<>>( \ - reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ - (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ - reinterpret_cast(counts.data_ptr()), H, G, \ - stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ - stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ - stride_counts_e); \ - } else { \ - vllm::silu_mul_fp8_quant_deep_gemm_kernel \ - <<>>( \ - reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ - (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ - reinterpret_cast(counts.data_ptr()), H, G, \ - stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ - stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ - stride_counts_e); \ - } + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - #define KERNEL_CALL_H \ - if (H % (4 * GROUP_SIZE) == 0) { \ - static constexpr int NUM_WARPS = 4; \ - populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ - KERNEL_FN \ - } else { \ - static constexpr int NUM_WARPS = 1; \ - populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ - KERNEL_FN \ + #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ + static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \ + int sms = SILU_V2_BLOCK_COUNT; \ + static constexpr int max_shared_mem_bytes = \ + GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \ + dim3 grid(sms), block(THREAD_COUNT); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + VLLM_DISPATCH_FP8_TYPES( \ + y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel< \ + BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \ + USE_UE8M0, GROUP_SIZE, STAGES> \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + reinterpret_cast(tokens_per_expert.data_ptr()), E, \ + T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \ + stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \ + stride_ys_g, stride_counts_e); \ + }); + + static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + + if (!use_ue8m0) { + if (H >= 4096) { + static constexpr int NUM_STAGES = 4; + static constexpr int THREAD_COUNT = 256; + KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); + } else { + static constexpr int THREAD_COUNT = 32; + KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2); } - - #define KERNEL_CALL_TOP_LEVEL \ - if (num_parallel_tokens == 1) { \ - static constexpr int NUM_PARALLEL_TOKENS = 1; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 2) { \ - static constexpr int NUM_PARALLEL_TOKENS = 2; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 4) { \ - static constexpr int NUM_PARALLEL_TOKENS = 4; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 8) { \ - static constexpr int NUM_PARALLEL_TOKENS = 8; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 16) { \ - static constexpr int NUM_PARALLEL_TOKENS = 16; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 32) { \ - static constexpr int NUM_PARALLEL_TOKENS = 32; \ - 
KERNEL_CALL_H \ - } else if (num_parallel_tokens == 64) { \ - static constexpr int NUM_PARALLEL_TOKENS = 64; \ - KERNEL_CALL_H \ + } else { + if (H >= 4096) { + static constexpr int NUM_STAGES = 4; + static constexpr int THREAD_COUNT = 256; + KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES); + } else { + static constexpr int THREAD_COUNT = 32; + KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2); } - - Idx_t G; - dim3 block, grid; - auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) { - G = H / Idx_t(group_size * num_warps); - grid = dim3(E * G, _num_parallel_tokens); - block = dim3(num_warps * WARP_SIZE); - }; - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(), - "silu_mul_fp8_quant_deep_gemm_kernel", - [&] { KERNEL_CALL_TOP_LEVEL }); + } #endif } diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 95aa92e25b30..92d6c2f402a2 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant( if (scale_ub.has_value()) { TORCH_CHECK(out.dtype() == kFp8Type); } + TORCH_CHECK(weight.dtype() == input.dtype()); TORCH_CHECK(scales.dtype() == torch::kFloat32); + if (residual) { + TORCH_CHECK(residual->scalar_type() == input.scalar_type()); + } VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh index 4e6118e52e8d..2b1eb1d568e4 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -6,7 +6,7 @@ #include "quantization/vectorization.cuh" // TODO(luka/varun):refactor common.cuh to use this file instead -#include "quantization/fp8/common.cuh" +#include "quantization/w8a8/fp8/common.cuh" namespace vllm { diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 43b245530e95..8869d7cd521b 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -185,7 +185,7 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, const uint32_t*, const half*, half*, const int, const int, const int, const int, - const int*); + const bool, const int*); template __global__ void gemm_half_q_half_gptq_4bit_kernel( @@ -193,12 +193,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, half* __restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const bool use_v2_format, const int* __restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 
0 : 1; + auto t = threadIdx.x; // Block @@ -256,10 +259,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]); // Column result float block_c[m_count][4] = {}; @@ -272,10 +275,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]); } #pragma unroll @@ -329,12 +332,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, half* __restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const bool use_v2_format, const int* __restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 
0 : 1; + auto t = threadIdx.x; // Block @@ -409,10 +415,10 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( int4 load_int4 = *b_ptr4; half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset); #pragma unroll for (int m = 0; m < m_count; m++) { @@ -448,12 +454,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, half* __restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const bool use_v2_format, const int* __restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto t = threadIdx.x; // Block @@ -534,13 +543,13 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( half2 dq[4][16]; dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], - size_n, zeros[0] + 1); + size_n, zeros[0] + zero_offset); dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], - size_n, zeros[1] + 1); + size_n, zeros[1] + zero_offset); dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], - size_n, zeros[2] + 1); + size_n, zeros[2] + zero_offset); dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], - size_n, zeros[3] + 1); + size_n, zeros[3] + zero_offset); #pragma unroll for (int m = 0; m < m_count; m++) { @@ -574,12 +583,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, half* __restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const bool use_v2_format, const int* __restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 
0 : 1; + auto t = threadIdx.x; // Block @@ -658,13 +670,13 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( half2 dq[4][4]; dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, - zeros[0] + 1); + zeros[0] + zero_offset); dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, - zeros[1] + 1); + zeros[1] + zero_offset); dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, - zeros[2] + 1); + zeros[2] + zero_offset); dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, - zeros[3] + 1); + zeros[3] + zero_offset); for (int m = 0; m < m_count; m++) { block_c[m][0] = @@ -730,7 +742,8 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_q_perm, half* c, int size_m, int size_n, int size_k, - int m_count, int groups, int bit) { + int m_count, int groups, bool use_v2_format, + int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -743,20 +756,23 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>>(a, b_q_weight, b_gptq_qzeros, - b_gptq_scales, c, size_m, size_n, - size_k, groups, b_q_perm); + kernel<<>>( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k, + groups, use_v2_format, b_q_perm); } __global__ void reconstruct_exllama_8bit_kernel( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const int groups, const bool use_v2_format, half* __restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; @@ -812,13 +828,13 @@ __global__ void reconstruct_exllama_8bit_kernel( half2 dq[4][4]; dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, - zeros[0] + 1); + zeros[0] + zero_offset); dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, - zeros[1] + 1); + zeros[1] + zero_offset); dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, - zeros[2] + 1); + zeros[2] + zero_offset); dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, - zeros[3] + 1); + zeros[3] + zero_offset); // half* dqh = (half*)dq; if (b_q_perm) { @@ -849,11 +865,14 @@ __global__ void reconstruct_exllama_4bit_kernel( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const int groups, const bool use_v2_format, half* __restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 
0 : 1; + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; @@ -888,10 +907,10 @@ __global__ void reconstruct_exllama_4bit_kernel( half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]); __syncthreads(); @@ -904,10 +923,10 @@ __global__ void reconstruct_exllama_4bit_kernel( nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]); } for (int p = 0; p < 4; p++) { @@ -954,11 +973,14 @@ __global__ void reconstruct_exllama_3bit_kernel( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const int groups, const bool use_v2_format, half* __restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 
0 : 1; + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; @@ -1016,13 +1038,13 @@ __global__ void reconstruct_exllama_3bit_kernel( half2 dq[4][16]; dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], - size_n, zeros[0] + 1); + size_n, zeros[0] + zero_offset); dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], - size_n, zeros[1] + 1); + size_n, zeros[1] + zero_offset); dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], - size_n, zeros[2] + 1); + size_n, zeros[2] + zero_offset); dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], - size_n, zeros[3] + 1); + size_n, zeros[3] + zero_offset); if (b_q_perm) { for (int j = 0; j < 16; j++) { @@ -1052,11 +1074,14 @@ __global__ void reconstruct_exllama_2bit_kernel( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const int groups, const bool use_v2_format, half* __restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; @@ -1108,10 +1133,10 @@ __global__ void reconstruct_exllama_2bit_kernel( int4 load_int4 = *b_ptr4; half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset); b_ptr += size_n; // half* dqh = (half*)dq; @@ -1143,7 +1168,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_q_perm, half* out, int height, int width, int groups, - int bit) { + bool use_v2_format, int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1162,14 +1187,14 @@ void reconstruct_exllama(const uint32_t* b_q_weight, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); reconstruct_exllama_kernel<<>>( b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, - out); + use_v2_format, out); } __global__ void gemm_half_q_half_alt_4bit_kernel( const half2* __restrict__ vec, const uint32_t* __restrict__ mat, half* __restrict__ mul, const half* __restrict__ scales, const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, - int batch, int height, int width) { + int batch, int height, int width, bool use_v2_format) { int zero_width = width / 8; int vec_height = height * 4; const int blockwidth2 = BLOCK_KN_SIZE / 2; @@ -1179,6 +1204,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 
0 : 1; + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; if (threadIdx.x < h_end) { for (int m = 0; m < b_end; ++m) { @@ -1223,10 +1251,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( half2 zero = __halves2half2( __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - - 1)), - __hmul(scale_f2, - __int2half_rn( - -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + zero_offset)), + __hmul( + scale_f2, + __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - + zero_offset))); scales_tmp[tmp_k] = scale; zeros_tmp[tmp_k] = zero; } @@ -1268,7 +1297,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( const half2* __restrict__ vec, const uint32_t* __restrict__ mat, half* __restrict__ mul, const half* __restrict__ scales, const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, - int batch, int height, int width) { + int batch, int height, int width, bool use_v2_format) { int zero_width = width / 4; int vec_height = height * 2; const int blockwidth2 = BLOCK_KN_SIZE / 2; @@ -1278,6 +1307,9 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; if (threadIdx.x < h_end) { for (int m = 0; m < b_end; ++m) { @@ -1312,12 +1344,13 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( half scale_f2 = scales[g2 * width + w]; half2 scale = __halves2half2(scale_f, scale_f2); half2 zero = __halves2half2( - __hmul(scale_f, - __int2half_rn( - -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), - __hmul(scale_f2, - __int2half_rn( - -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + __hmul(scale_f, __int2half_rn( + -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - + zero_offset)), + __hmul( + scale_f2, + __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - + zero_offset))); scales_tmp[tmp_k] = scale; zeros_tmp[tmp_k] = zero; } @@ -1355,7 +1388,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_g_idx, half* c, int size_m, int size_n, int size_k, - int bit) { + bool use_v2_format, int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1372,17 +1405,15 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); kernel<<>>( (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, - size_m, size_k / 32 * bit, size_n); + size_m, size_k / 32 * bit, size_n, use_v2_format); } template -__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, const int width, - const int group, - half* __restrict__ out) { +__global__ void reconstruct_gptq_kernel( + const uint32_t* __restrict__ w, const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, + const int height, const int width, const int group, + const bool use_v2_format, half* __restrict__ out) { // Start of block auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; @@ -1395,6 +1426,9 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, MatrixView_half w_scales_(w_scales, group, width); T w_zeros_(w_zeros, 
group, width); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + uint32_t w_read = w[blockIdx.y * width + column]; half* out_ptr = out_.item_ptr(row, column); @@ -1402,7 +1436,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, for (int s = 0; s < 32; s += bit) { int group = g_idx[row + s / bit]; half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; + uint32_t w_zero = w_zeros_.item(group, column) + zero_offset; half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale); @@ -1415,7 +1449,7 @@ __global__ void reconstruct_gptq_3bit_kernel( const uint32_t* __restrict__ w, const half* __restrict__ w_scales, const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, const int height, const int width, const int group, - half* __restrict__ out) { + const bool use_v2_format, half* __restrict__ out) { // Start of block auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto row = blockIdx.y * 32; @@ -1427,6 +1461,9 @@ __global__ void reconstruct_gptq_3bit_kernel( MatrixView_half w_scales_(w_scales, group, width); MatrixView_q3_row w_zeros_(w_zeros, group, width); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; @@ -1436,7 +1473,7 @@ __global__ void reconstruct_gptq_3bit_kernel( for (int i = 0; i < 32; i += 1) { int group = g_idx[row + i]; half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; + uint32_t w_zero = w_zeros_.item(group, column) + zero_offset; int w_item; if (i == 10) { w_item = (w1 >> 30) | ((w2 << 2) & 0x4); @@ -1456,7 +1493,8 @@ __global__ void reconstruct_gptq_3bit_kernel( void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_g_idx, half* out, - int height, int width, int groups, int bit) { + int height, int width, int groups, bool use_v2_format, + int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1476,7 +1514,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); kernel<<>>(b_q_weight, b_gptq_scales, b_gptq_qzeros, b_g_idx, height, - width, groups, out); + width, groups, use_v2_format, out); } void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, @@ -1484,7 +1522,8 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_g_idx, half* c, half* temp_dq, int size_m, int size_n, - int size_k, int groups, bool use_exllama, int bit) { + int size_k, int groups, bool use_exllama, + bool use_v2_format, int bit) { bool use_reconstruct; if (use_exllama) { use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || @@ -1498,10 +1537,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); + temp_dq, size_k, size_n, groups, use_v2_format, bit); } else { reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); + temp_dq, size_k, 
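For reference, a minimal scalar sketch of the zero-point convention these GPTQ kernels switch on: GPTQv1-style checkpoints store the zero point minus one, so the kernels add one back, while the v2 format stores it directly. The helper below is illustrative only and skips the packed/vectorized paths used above:

// q: unpacked quantized weight, z: stored zero point for its group.
static inline float dequant_scalar(int q, int z, float scale,
                                   bool use_v2_format) {
  // Same selection as the kernels: no offset for the v2 format, +1 for v1.
  int zero_offset = use_v2_format ? 0 : 1;
  return scale * static_cast<float>(q - (z + zero_offset));
}

For example, dequant_scalar(3, 8, 0.02f, /*use_v2_format=*/false) subtracts (8 + 1), matching the v1 behavior the kernels previously hard-coded as "+ 1".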
size_n, groups, use_v2_format, bit); } const half alpha = __float2half(1.0f); @@ -1517,18 +1556,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, if (max_chunks) { gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, last_chunk, size_n, size_k, - BLOCK_M_SIZE_MAX, groups, bit); + BLOCK_M_SIZE_MAX, groups, use_v2_format, bit); } if (last_chunk_size) { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, - b_gptq_qzeros, b_gptq_scales, b_g_idx, - c + last_chunk * size_n, last_chunk_size, - size_n, size_k, last_chunk_size, groups, bit); + gemm_half_q_half_cuda_part( + a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c + last_chunk * size_n, last_chunk_size, size_n, size_k, + last_chunk_size, groups, use_v2_format, bit); } } else { gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, size_m, size_n, size_k, bit); + c, size_m, size_n, size_k, use_v2_format, bit); } } @@ -1815,7 +1854,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit) { + bool use_exllama, bool use_v2_format, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); @@ -1833,7 +1872,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, c.size(1), // n a.size(1), // k b_gptq_qzeros.size(0), // group number - use_exllama, bit); + use_exllama, use_v2_format, bit); return c; } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index d29a199c5d32..8bd17ba69cec 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -9,7 +9,6 @@ from copy import deepcopy from dataclasses import dataclass, fields from functools import reduce -from typing import Optional, Union import jinja2 from vllm_cutlass_library_extension import ( @@ -259,7 +258,7 @@ class ScheduleConfig: @dataclass(frozen=True) class TypeConfig: a: DataType - b: Union[DataType, VLLMDataType] + b: DataType | VLLMDataType b_group_scale: DataType b_group_zeropoint: DataType b_channel_scale: DataType @@ -280,7 +279,7 @@ class PrepackTypeConfig: class ImplConfig: types: TypeConfig schedules: list[ScheduleConfig] - heuristic: list[tuple[Optional[str], ScheduleConfig]] + heuristic: list[tuple[str | None, ScheduleConfig]] def generate_sch_sig(schedule_config: ScheduleConfig) -> str: diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/w8a8/cutlass/Epilogues.md similarity index 100% rename from csrc/quantization/cutlass_w8a8/Epilogues.md rename to csrc/quantization/w8a8/cutlass/Epilogues.md diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh rename to csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh diff --git 
a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh rename to 
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu rename to csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu diff --git a/csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh b/csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh rename to csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu rename to csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu similarity index 100% rename from 
csrc/quantization/cutlass_w8a8/moe/moe_data.cu rename to csrc/quantization/w8a8/cutlass/moe/moe_data.cu diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cuh diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh rename to csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu rename to csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu similarity index 100% rename from csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu rename to csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/w8a8/fp8/amd/quant_utils.cuh similarity index 99% rename from csrc/quantization/fp8/amd/quant_utils.cuh rename to csrc/quantization/w8a8/fp8/amd/quant_utils.cuh index e51a4e14e518..81f5cb83f3e1 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/w8a8/fp8/amd/quant_utils.cuh @@ -5,7 +5,7 @@ #include #include -#include "../../../attention/attention_dtypes.h" +#include "../../../../attention/attention_dtypes.h" namespace vllm { #ifdef USE_ROCM diff --git 
a/csrc/quantization/fp8/common.cu b/csrc/quantization/w8a8/fp8/common.cu similarity index 99% rename from csrc/quantization/fp8/common.cu rename to csrc/quantization/w8a8/fp8/common.cu index 45d6d5082ce4..7a822fb8fb8a 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/w8a8/fp8/common.cu @@ -1,7 +1,7 @@ #include "common.cuh" #include "dispatch_utils.h" -#include "../../cub_helpers.h" -#include "../vectorization_utils.cuh" +#include "cub_helpers.h" +#include "quantization/vectorization_utils.cuh" #include #include diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/w8a8/fp8/common.cuh similarity index 100% rename from csrc/quantization/fp8/common.cuh rename to csrc/quantization/w8a8/fp8/common.cuh diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh similarity index 99% rename from csrc/quantization/fp8/nvidia/quant_utils.cuh rename to csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh index 5361a8b1a598..421e8092474b 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/w8a8/fp8/nvidia/quant_utils.cuh @@ -1,6 +1,6 @@ #pragma once -#include "../../../attention/attention_dtypes.h" +#include "../../../../attention/attention_dtypes.h" #include #include #include diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu similarity index 98% rename from csrc/quantization/fp8/per_token_group_quant.cu rename to csrc/quantization/w8a8/fp8/per_token_group_quant.cu index 91d489fdef86..e3ab0676b254 100644 --- a/csrc/quantization/fp8/per_token_group_quant.cu +++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu @@ -1,6 +1,6 @@ #include -#include "../per_token_group_quant_8bit.h" +#include "quantization/w8a8/per_token_group_quant_8bit.h" #include @@ -8,9 +8,9 @@ #include -#include "../vectorization.cuh" -#include "../vectorization_utils.cuh" -#include "../../dispatch_utils.h" +#include "quantization/vectorization.cuh" +#include "quantization/vectorization_utils.cuh" +#include "dispatch_utils.h" __device__ __forceinline__ float GroupReduceMax(float val) { unsigned mask = threadIdx.x % 32 >= 16 ? 
0xffff0000 : 0x0000ffff; @@ -212,4 +212,4 @@ void per_token_group_quant_fp8(const torch::Tensor& input, double fp8_max, bool scale_ue8m0) { per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0); -} +} \ No newline at end of file diff --git a/csrc/quantization/w8a8/int8/per_token_group_quant.cu b/csrc/quantization/w8a8/int8/per_token_group_quant.cu new file mode 100644 index 000000000000..9d808a176f53 --- /dev/null +++ b/csrc/quantization/w8a8/int8/per_token_group_quant.cu @@ -0,0 +1,12 @@ +#include +#include + +#include "quantization/w8a8/per_token_group_quant_8bit.h" + +void per_token_group_quant_int8(const torch::Tensor& input, + torch::Tensor& output_q, + torch::Tensor& output_s, int64_t group_size, + double eps, double int8_min, double int8_max) { + per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, + int8_min, int8_max); +} \ No newline at end of file diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/w8a8/int8/scaled_quant.cu similarity index 94% rename from csrc/quantization/compressed_tensors/int8_quant_kernels.cu rename to csrc/quantization/w8a8/int8/scaled_quant.cu index bcfde9fbcbbe..7fe9e96bfb01 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/w8a8/int8/scaled_quant.cu @@ -1,15 +1,11 @@ #include #include -#ifndef USE_ROCM - #include "../per_token_group_quant_8bit.h" -#endif - #include -#include "../../cub_helpers.h" -#include "../../dispatch_utils.h" -#include "../vectorization_utils.cuh" +#include "dispatch_utils.h" +#include "quantization/vectorization_utils.cuh" +#include "cub_helpers.h" static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM @@ -25,7 +21,6 @@ static inline __device__ int8_t float_to_int8_rn(float x) { float dst = std::nearbyint(x); // saturate - // See https://github.com/pytorch/pytorch/issues/127666 // See https://github.com/llvm/llvm-project/issues/95183 // hip-clang std::clamp __glibcxx_assert_fail host function when building on @@ -84,7 +79,6 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { static_cast(std::numeric_limits::max()); // saturate - // See https://github.com/pytorch/pytorch/issues/127666 // See https://github.com/llvm/llvm-project/issues/95183 // hip-clang std::clamp __glibcxx_assert_fail host function when building on @@ -176,7 +170,6 @@ __global__ void dynamic_scaled_int8_quant_kernel( float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax; - // 2. quantize vectorize_with_alignment<16>( row_in, row_out, hidden_size, tid, stride, [=] __device__(int8_t& dst, const scalar_t& src) { @@ -194,7 +187,6 @@ struct MinMax { __host__ __device__ explicit MinMax(float v) : min(v), max(v) {} - // add a value to the MinMax __host__ __device__ MinMax& operator+=(float v) { min = fminf(min, v); max = fmaxf(max, v); @@ -228,7 +220,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( const scalar_t* row_in = input + token_idx * hidden_size; int8_t* row_out = output + token_idx * hidden_size; - // 1. calculate min & max MinMax thread_mm; vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) { @@ -261,7 +252,6 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( const float inv_s = 1.f / scale_sh; const azp_t azp = azp_sh; - // 2. 
quantize vectorize_with_alignment<16>( row_in, row_out, hidden_size, tid, stride, [=] __device__(int8_t& dst, const scalar_t& src) { @@ -332,14 +322,4 @@ void dynamic_scaled_int8_quant( hidden_size); } }); -} - -#ifndef USE_ROCM -void per_token_group_quant_int8(const torch::Tensor& input, - torch::Tensor& output_q, - torch::Tensor& output_s, int64_t group_size, - double eps, double int8_min, double int8_max) { - per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, - int8_min, int8_max); -} -#endif +} \ No newline at end of file diff --git a/csrc/quantization/per_token_group_quant_8bit.h b/csrc/quantization/w8a8/per_token_group_quant_8bit.h similarity index 84% rename from csrc/quantization/per_token_group_quant_8bit.h rename to csrc/quantization/w8a8/per_token_group_quant_8bit.h index 537b61bc4303..25d4ecd1131a 100644 --- a/csrc/quantization/per_token_group_quant_8bit.h +++ b/csrc/quantization/w8a8/per_token_group_quant_8bit.h @@ -1,7 +1,6 @@ #pragma once #include -// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders // 8-bit per-token-group quantization helper used by both FP8 and INT8 void per_token_group_quant_8bit(const torch::Tensor& input, torch::Tensor& output_q, diff --git a/csrc/quickreduce/quick_reduce.h b/csrc/quickreduce/quick_reduce.h index 4fe4c44be7eb..4cc35300bf87 100644 --- a/csrc/quickreduce/quick_reduce.h +++ b/csrc/quickreduce/quick_reduce.h @@ -22,13 +22,14 @@ template __global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, int rank, uint8_t** dbuffer_list, - uint32_t data_offset, uint32_t flag_color) { + uint32_t data_offset, uint32_t flag_color, + int64_t data_size_per_phase) { int block = blockIdx.x; int grid = gridDim.x; while (block < num_blocks) { AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, - flag_color); + flag_color, data_size_per_phase); block += grid; flag_color++; } @@ -41,21 +42,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, hipLaunchKernelGGL((allreduce_prototype_twoshot), \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ num_blocks, rank, dbuffer_list, data_offset, \ - flag_color); \ + flag_color, this->kMaxProblemSize); \ } else if (world_size == 4) { \ using LineCodec = __codec; \ using AllReduceKernel = AllReduceTwoshot; \ hipLaunchKernelGGL((allreduce_prototype_twoshot), \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ num_blocks, rank, dbuffer_list, data_offset, \ - flag_color); \ + flag_color, this->kMaxProblemSize); \ } else if (world_size == 8) { \ using LineCodec = __codec; \ using AllReduceKernel = AllReduceTwoshot; \ hipLaunchKernelGGL((allreduce_prototype_twoshot), \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ num_blocks, rank, dbuffer_list, data_offset, \ - flag_color); \ + flag_color, this->kMaxProblemSize); \ } enum QuickReduceQuantLevel { diff --git a/csrc/quickreduce/quick_reduce_impl.cuh b/csrc/quickreduce/quick_reduce_impl.cuh index 17816c552d25..38dc9938fc8a 100644 --- a/csrc/quickreduce/quick_reduce_impl.cuh +++ b/csrc/quickreduce/quick_reduce_impl.cuh @@ -553,13 +553,12 @@ struct AllReduceTwoshot { int const rank, // rank index uint8_t** __restrict__ buffer_list, // communication buffers uint32_t const data_offset, // offset to start of the data buffer - uint32_t flag_color) { + uint32_t flag_color, int64_t data_size_per_phase) { // Topology int thread = threadIdx.x + threadIdx.y * kWavefront; uint8_t* rank_buffer = 
buffer_list[rank]; Codec codec(thread, rank); int block_id = blockIdx.x; - int grid_size = gridDim.x; // -------------------------------------------------------- // Read input into registers int32x4_t tA[kAtoms]; @@ -588,12 +587,10 @@ struct AllReduceTwoshot { // rank responsible for this segment. uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize; - uint32_t comm_data1_offset = - grid_size * Codec::kTransmittedTileSize + comm_data0_offset; + uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset; uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); - uint32_t comm_flags1_offset = - grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; + uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset; for (int r = 0; r < kWorldSize; r++) { int32x4_t* send_buffer = diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index df3208a120f1..a339c5641bb4 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -23,7 +23,7 @@ #include #include "../attention/dtype_fp8.cuh" -#include "../quantization/fp8/amd/quant_utils.cuh" +#include "../quantization/w8a8/fp8/amd/quant_utils.cuh" // ROCm 6.2 compatibility: map OCP fp8 types to FNUZ variants if OCP is absent #if !defined(HIP_FP8_TYPE_OCP) diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index e4600350d3ea..2ef579a1b753 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -11,7 +11,7 @@ #include "../cuda_compat.h" #include "dispatch_utils.h" -#include "quantization/fp8/common.cuh" +#include "quantization/w8a8/fp8/common.cuh" #if defined(__HIPCC__) && \ (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) diff --git a/csrc/sampler.cu b/csrc/sampler.cu index bc589d99d04b..410b8988f493 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -54,15 +54,10 @@ static inline __device__ uint16_t extractBinIdx(float x) { return 511 - (tmp.u16 >> 7); } -template -static __global__ void topKPerRow(const float* logits, const int* rowStarts, - const int* rowEnds, int* outIndices, - float* outLogits, int stride0, int stride1) { - // The number of bins in the histogram. - static constexpr int kNumBins = 512; - - // The top-k width. - static constexpr int kTopK = 2048; +template +__device__ void topKPerRowJob(const float* logits, const int rowStart, + const int rowEnd, const int rowIdx, + int* outIndices, int stride0, int stride1) { // The number of elements per thread for the final top-k sort. static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock; // The class to sort the elements during the final top-k sort. @@ -103,17 +98,11 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, __shared__ int smemHistogram[kNumBins]; // Shared memory to store the selected indices. __shared__ int smemIndices[kTopK]; - // Shared memory to store the selected logits. - __shared__ float smemLogits[kTopK]; // Shared memory to store the threshold bin. __shared__ int smemThresholdBinIdx[1]; // Shared memory counter to register the candidates for the final phase. __shared__ int smemFinalDstIdx[1]; - // The row computed by this block. - int rowIdx = blockIdx.x; - // The range of logits within the row. - int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx]; // The length of the row. 
int rowLen = rowEnd - rowStart; @@ -124,13 +113,10 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, rowIt += kNumThreadsPerBlock) { int idx = rowStart + rowIt; outIndices[rowIdx * kTopK + rowIt] = idx - rowStart; - outLogits[rowIdx * kTopK + rowIt] = - logits[rowIdx * stride0 + idx * stride1]; } for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK; rowIt += kNumThreadsPerBlock) { outIndices[rowIdx * kTopK + rowIt] = -1; - outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX; } return; } @@ -201,7 +187,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, uint16_t idx = extractBinIdx(logit); if (idx < thresholdBinIdx) { int dstIdx = atomicAdd(&smemHistogram[idx], 1); - smemLogits[dstIdx] = logit; smemIndices[dstIdx] = rowIt; } else if (idx == thresholdBinIdx) { int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1); @@ -250,7 +235,6 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x; int dstIdx = baseIdx + srcIdx; if (dstIdx < kTopK) { - smemLogits[dstIdx] = finalLogits[ii]; smemIndices[dstIdx] = finalIndices[ii]; } } @@ -258,31 +242,58 @@ static __global__ void topKPerRow(const float* logits, const int* rowStarts, // Make sure the data is in shared memory. __syncthreads(); - // The topK logits. - float topKLogits[kNumTopKItemsPerThread]; - // The topK indices. - int topKIndices[kNumTopKItemsPerThread]; - -// Load from shared memory. -#pragma unroll - for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { - topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x]; - topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x]; - } - - // Sort the elements. - TopKSort(smemFinal.topKSort) - .SortDescendingBlockedToStriped(topKLogits, topKIndices); - // Store to global memory. #pragma unroll for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) { int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x; - outIndices[offset] = topKIndices[ii] - rowStart; - outLogits[offset] = topKLogits[ii]; + outIndices[offset] = + smemIndices[ii * kNumThreadsPerBlock + threadIdx.x] - rowStart; } } +template +static __global__ void topKPerRow(const float* logits, const int* rowStarts, + const int* rowEnds, int* outIndices, + int stride0, int stride1) { + // The number of bins in the histogram. + static constexpr int kNumBins = 512; + + // The top-k width. + static constexpr int kTopK = 2048; + + // The row computed by this block. + int rowIdx = blockIdx.x; + + // The range of logits within the row. + int rowStart = rowStarts[rowIdx]; + int rowEnd = rowEnds[rowIdx]; + + topKPerRowJob( + logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); +} + +template +static __global__ void topKPerRowDecode(const float* logits, const int* seqLens, + int* outIndices, int stride0, + int stride1, int next_n) { + // The number of bins in the histogram. + static constexpr int kNumBins = 512; + + // The top-k width. + static constexpr int kTopK = 2048; + + // The row computed by this block. + int rowIdx = blockIdx.x; + + // The range of logits within the row. 
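For reference, a host-side sketch of the row-range derivation the decode kernel computes next: each sequence contributes next_n rows, row r maps to sequence r / next_n, and its valid logits span [0, seq_len - next_n + (r % next_n) + 1). The struct and function names are illustrative only:

#include <cassert>

struct RowRange {
  int start;
  int end;  // exclusive
};

inline RowRange decode_row_range(const int* seq_lens, int row_idx, int next_n) {
  int seq_len = seq_lens[row_idx / next_n];
  return {0, seq_len - next_n + (row_idx % next_n) + 1};
}

int main() {
  const int seq_lens[] = {10, 7};
  // Two sequences with next_n = 2: rows 0-1 belong to the first sequence,
  // rows 2-3 to the second.
  assert(decode_row_range(seq_lens, 0, 2).end == 9);  // 10 - 2 + 0 + 1
  assert(decode_row_range(seq_lens, 3, 2).end == 7);  //  7 - 2 + 1 + 1
  return 0;
}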
+ int rowStart = 0; + int seq_len = seqLens[rowIdx / next_n]; + int rowEnd = seq_len - next_n + (rowIdx % next_n) + 1; + + topKPerRowJob( + logits, rowStart, rowEnd, rowIdx, outIndices, stride0, stride1); +} + } // namespace vllm void apply_repetition_penalties_( @@ -326,10 +337,23 @@ void apply_repetition_penalties_( }); } +void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, + const torch::Tensor& seqLens, torch::Tensor& indices, + int64_t numRows, int64_t stride0, int64_t stride1) { + // Compute the results on the device. + constexpr int kNumThreadsPerBlock = 512; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + vllm::topKPerRowDecode + <<>>( + logits.data_ptr(), seqLens.data_ptr(), + indices.data_ptr(), static_cast(stride0), + static_cast(stride1), static_cast(next_n)); +} + void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, const torch::Tensor& rowEnds, torch::Tensor& indices, - torch::Tensor& values, int64_t numRows, int64_t stride0, - int64_t stride1) { + int64_t numRows, int64_t stride0, int64_t stride1) { // Compute the results on the device. constexpr int kNumThreadsPerBlock = 512; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -338,6 +362,5 @@ void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts, <<>>( logits.data_ptr(), rowStarts.data_ptr(), rowEnds.data_ptr(), indices.data_ptr(), - values.data_ptr(), static_cast(stride0), - static_cast(stride1)); + static_cast(stride0), static_cast(stride1)); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bef8cdc33f13..8f091a429fbe 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -33,11 +33,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #endif ops.def( - "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! " - "y_q, Tensor! y_s, int group_size, " - "bool use_ue8m0, int num_parallel_tokens) -> ()"); - ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA, - &silu_mul_fp8_quant_deep_gemm_cuda); + "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! " + "y_q, Tensor! y_s," + "bool use_ue8m0) -> ()"); + ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA, + &persistent_masked_m_silu_mul_quant); ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); @@ -175,12 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); - // Polynomial Normalization. - ops.def( - "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float " - "epsilon) -> ()"); - ops.impl("poly_norm", torch::kCUDA, &poly_norm); - // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " @@ -191,10 +185,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Optimized top-k per row operation ops.def( "top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, " - "Tensor! indices, Tensor! values, int numRows, int stride0, " + "Tensor! indices, int numRows, int stride0, " "int stride1) -> ()"); ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row); + ops.def( + "top_k_per_row_decode(Tensor logits, int next_n, " + "Tensor seq_lens, Tensor! 
indices, int numRows, " + "int stride0, int stride1) -> ()"); + ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode); + // Layernorm-quant // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -557,7 +557,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // to prevent the meta function registry. ops.def( "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " - "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) " + "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool " + "use_v2_format, int bit) " "-> Tensor", {stride_tag}); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); diff --git a/docker/Dockerfile b/docker/Dockerfile index f9df931e73b1..eb1453126e6f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ # docs/contributing/dockerfile/dockerfile.md and # docs/assets/contributing/dockerfile-stages-dependency.png -ARG CUDA_VERSION=12.8.1 +ARG CUDA_VERSION=12.9.1 ARG PYTHON_VERSION=3.12 # By parameterizing the base images, we allow third-party to use their own @@ -15,7 +15,7 @@ ARG PYTHON_VERSION=3.12 # Example: # docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 -# Important: We build with an old version of Ubuntu to maintain broad +# Important: We build with an old version of Ubuntu to maintain broad # compatibility with other Linux OSes. The main reason for this is that the # glibc version is baked into the distro, and binaries built with one glibc # version are not backwards compatible with OSes that use an earlier version. @@ -132,7 +132,9 @@ WORKDIR /workspace COPY requirements/common.txt requirements/common.txt COPY requirements/cuda.txt requirements/cuda.txt RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \ + # TODO: remove apache-tvm-ffi once FlashInfer is fixed https://github.com/flashinfer-ai/flashinfer/issues/1962 + uv pip install --python /opt/venv/bin/python3 --pre apache-tvm-ffi==0.1.0b15 \ + && uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # cuda arch list used by torch @@ -229,7 +231,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ # Check the size of the wheel if RUN_WHEEL_CHECK is true COPY .buildkite/check-wheel-size.py check-wheel-size.py # sync the default value with .buildkite/check-wheel-size.py -ARG VLLM_MAX_SIZE_MB=450 +ARG VLLM_MAX_SIZE_MB=500 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ARG RUN_WHEEL_CHECK=true RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ @@ -273,6 +275,7 @@ WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive ARG TARGETPLATFORM +# TODO (huydhn): There is no prebuilt gdrcopy package on 12.9 at the moment ARG GDRCOPY_CUDA_VERSION=12.8 # Keep in line with FINAL_BASE_IMAGE ARG GDRCOPY_OS_VERSION=Ubuntu22_04 @@ -353,78 +356,26 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Install vllm wheel first, so that torch etc will be installed. 
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system dist/*.whl --verbose \ + # TODO: remove apache-tvm-ffi once FlashInfer is fixed https://github.com/flashinfer-ai/flashinfer/issues/1962 + uv pip install --system --pre apache-tvm-ffi==0.1.0b15 \ + && uv pip install --system dist/*.whl --verbose \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') -# If we need to build FlashInfer wheel before its release: -# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+ -# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' -# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive -# $ cd flashinfer -# $ git checkout v0.2.6.post1 -# $ python -m flashinfer.aot -# $ python -m build --no-isolation --wheel -# $ ls -la dist -# -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl -# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl - -# Install FlashInfer from source -ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -# Keep this in sync with "flashinfer" extra in setup.py -ARG FLASHINFER_GIT_REF="v0.3.1" -# Flag to control whether to compile FlashInfer AOT kernels -# Set to "true" to enable AOT compilation: -# docker build --build-arg FLASHINFER_AOT_COMPILE=true ... -ARG FLASHINFER_AOT_COMPILE=false +# TODO (huydhn): Remove this once xformers is released for 2.9.0 RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' - . /etc/environment - git clone --depth 1 --recursive --shallow-submodules \ - --branch ${FLASHINFER_GIT_REF} \ - ${FLASHINFER_GIT_REPO} flashinfer - # Exclude CUDA arches for older versions (11.x and 12.0-12.7) - # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. - if [[ "${CUDA_VERSION}" == 11.* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" - elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" - else - # CUDA 12.8+ supports 10.0a and 12.0 - FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" - fi - pushd flashinfer - if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then - # NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh - echo "🏗️ Installing FlashInfer from pre-compiled wheel" - uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \ - --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') - if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then - # Download pre-compiled cubins - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." - fi - elif [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then - echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" - export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" - # HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future - uv pip install --system cuda-python==$(echo $CUDA_VERSION | cut -d. -f1,2) pynvml==$(echo $CUDA_VERSION | cut -d. -f1) nvidia-nvshmem-cu$(echo $CUDA_VERSION | cut -d. 
-f1) - # Build AOT kernels - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer.aot - # Install with no-build-isolation since we already built AOT kernels - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - uv pip install --system --no-build-isolation . \ - --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') - # Download pre-compiled cubins - TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ - python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." - else - echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode" - uv pip install --system . \ - --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') - fi - popd - rm -rf flashinfer + . /etc/environment + export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a' + uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2" BASH + +# Install FlashInfer pre-compiled kernel cache and binaries +# https://docs.flashinfer.ai/installation.html +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system flashinfer-cubin==0.4.1 \ + && uv pip install --system flashinfer-jit-cache==0.4.1 \ + --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ + && flashinfer show-config + COPY examples examples COPY benchmarks benchmarks COPY ./vllm/collect_env.py . @@ -483,6 +434,7 @@ ARG PYTHON_VERSION ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 @@ -495,7 +447,8 @@ ENV UV_LINK_MODE=copy RUN --mount=type=cache,target=/root/.cache/uv \ CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ if [ "$CUDA_MAJOR" -ge 12 ]; then \ - uv pip install --system -r requirements/dev.txt; \ + uv pip install --system -r requirements/dev.txt \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. 
-f1,2 | tr -d '.'); \ fi # install development dependencies (for testing) @@ -542,7 +495,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ else \ BITSANDBYTES_VERSION="0.46.1"; \ fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3]>=0.14.0' + uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.14.0' ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 2ed02ff9e3ac..f3fd1ee3e32b 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -13,7 +13,7 @@ # vllm-dev: used for development # # Build arguments: -# PYTHON_VERSION=3.12 (default)|3.11|3.10|3.9 +# PYTHON_VERSION=3.13|3.12 (default)|3.11|3.10 # VLLM_CPU_DISABLE_AVX512=false (default)|true # VLLM_CPU_AVX512BF16=false (default)|true # VLLM_CPU_AVX512VNNI=false (default)|true @@ -31,7 +31,7 @@ ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \ apt-get update -y \ - && apt-get install -y --no-install-recommends ccache git curl wget ca-certificates \ + && apt-get install -y --no-install-recommends sudo ccache git curl wget ca-certificates \ gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \ && curl -LsSf https://astral.sh/uv/install.sh | sh @@ -79,6 +79,9 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc ######################### BUILD IMAGE ######################### FROM base AS vllm-build +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} + ARG GIT_REPO_CHECK=0 # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... ARG VLLM_CPU_DISABLE_AVX512=0 @@ -104,16 +107,20 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/workspace/vllm/.deps,sharing=locked \ --mount=type=bind,source=.git,target=.git \ - VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel + VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 ######################### TEST DEPS ######################### FROM base AS vllm-test-deps WORKDIR /workspace/vllm +# TODO: Update to 2.9.0 when there is a new build for intel_extension_for_pytorch for that version RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ cp requirements/test.in requirements/cpu-test.in && \ sed -i '/mamba_ssm/d' requirements/cpu-test.in && \ + sed -i 's/^torch==.*/torch==2.8.0/g' requirements/cpu-test.in && \ + sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ + sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index 6a9c3fa7dbed..6dfa56017838 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2. 
# build flashinfer for torch nightly from source around 10 mins -# release version: v0.3.1 +# release version: v0.4.1 # todo(elainewy): cache flashinfer build result for faster build ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ @@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ echo "git clone flashinfer..." \ && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \ && cd flashinfer \ - && git checkout v0.3.1 \ + && git checkout v0.4.1\ && git submodule update --init --recursive \ && echo "finish git clone flashinfer..." \ && rm -rf build \ diff --git a/docker/Dockerfile.ppc64le b/docker/Dockerfile.ppc64le index 5eaef4ea980d..ad9eae94b83d 100644 --- a/docker/Dockerfile.ppc64le +++ b/docker/Dockerfile.ppc64le @@ -1,4 +1,4 @@ -ARG BASE_UBI_IMAGE_TAG=9.5-1741850109 +ARG BASE_UBI_IMAGE_TAG=9.6-1754584681 ############################################################### # Stage to build openblas @@ -7,7 +7,7 @@ ARG BASE_UBI_IMAGE_TAG=9.5-1741850109 FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS openblas-builder ARG MAX_JOBS -ARG OPENBLAS_VERSION=0.3.29 +ARG OPENBLAS_VERSION=0.3.30 RUN microdnf install -y dnf && dnf install -y gcc-toolset-13 make wget unzip \ && source /opt/rh/gcc-toolset-13/enable \ && wget https://github.com/OpenMathLib/OpenBLAS/releases/download/v$OPENBLAS_VERSION/OpenBLAS-$OPENBLAS_VERSION.zip \ @@ -38,7 +38,7 @@ RUN dnf install -y openjpeg2-devel lcms2-devel tcl-devel tk-devel fribidi-devel FROM centos-deps-builder AS base-builder ARG PYTHON_VERSION=3.12 -ARG OPENBLAS_VERSION=0.3.29 +ARG OPENBLAS_VERSION=0.3.30 # Set Environment Variables for venv, cargo & openblas ENV VIRTUAL_ENV=/opt/vllm @@ -61,7 +61,7 @@ RUN --mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/, pkgconfig xsimd zeromq-devel kmod findutils protobuf* \ libtiff-devel libjpeg-devel zlib-devel freetype-devel libwebp-devel \ harfbuzz-devel libraqm-devel libimagequant-devel libxcb-devel \ - python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \ + python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip clang-devel \ && dnf clean all \ && PREFIX=/usr/local make -C /openblas install \ && ln -sf /usr/lib64/libatomic.so.1 /usr/lib64/libatomic.so \ @@ -79,9 +79,9 @@ RUN --mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/, FROM base-builder AS torch-builder ARG MAX_JOBS -ARG TORCH_VERSION=2.6.0 +ARG TORCH_VERSION=2.7.0 ARG _GLIBCXX_USE_CXX11_ABI=1 -ARG OPENBLAS_VERSION=0.3.29 +ARG OPENBLAS_VERSION=0.3.30 RUN --mount=type=cache,target=/root/.cache/uv \ source /opt/rh/gcc-toolset-13/enable && \ @@ -93,7 +93,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ MAX_JOBS=${MAX_JOBS:-$(nproc)} \ PYTORCH_BUILD_VERSION=${TORCH_VERSION} PYTORCH_BUILD_NUMBER=1 uv build --wheel --out-dir /torchwheels/ -ARG TORCHVISION_VERSION=0.21.0 +ARG TORCHVISION_VERSION=0.22.0 ARG TORCHVISION_USE_NVJPEG=0 ARG TORCHVISION_USE_FFMPEG=0 RUN --mount=type=cache,target=/root/.cache/uv \ @@ -104,7 +104,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ BUILD_VERSION=${TORCHVISION_VERSION} \ uv build --wheel --out-dir /torchwheels/ --no-build-isolation -ARG TORCHAUDIO_VERSION=2.6.0 +ARG TORCHAUDIO_VERSION=2.7.0 ARG BUILD_SOX=1 ARG BUILD_KALDI=1 ARG BUILD_RNNT=1 @@ -128,7 +128,7 @@ FROM base-builder AS arrow-builder ARG MAX_JOBS ARG PYARROW_PARALLEL -ARG PYARROW_VERSION=19.0.1 +ARG PYARROW_VERSION=21.0.0 RUN --mount=type=cache,target=/root/.cache/uv \ source /opt/rh/gcc-toolset-13/enable && 
\ git clone --recursive https://github.com/apache/arrow.git -b apache-arrow-${PYARROW_VERSION} && \ @@ -145,7 +145,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ make install -j ${MAX_JOBS:-$(nproc)} && \ cd ../../python/ && \ uv pip install -v -r requirements-build.txt && uv pip install numpy==2.1.3 && \ - pip show numpy && ls -lrt /opt/vllm/lib/python3.12/site-packages/numpy && \ PYARROW_PARALLEL=${PYARROW_PARALLEL:-$(nproc)} \ python setup.py build_ext \ --build-type=release --bundle-arrow-cpp \ @@ -187,6 +186,23 @@ RUN git clone --recursive https://github.com/numactl/numactl.git -b v${NUMACTL_V && make -j ${MAX_JOBS:-$(nproc)} +############################################################### +# Stage to build numba +############################################################### + +FROM base-builder AS numba-builder + +ARG MAX_JOBS +ARG NUMBA_VERSION=0.61.2 + +# Clone all required dependencies +RUN dnf install ninja-build llvm15 llvm15-devel -y && source /opt/rh/gcc-toolset-13/enable && export PATH=$PATH:/usr/lib64/llvm15/bin && \ + git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \ + cd ./numba && \ + if ! grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \ + sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ + fi && python -m build --wheel --installer=uv --outdir /numbawheels/ + ############################################################### # Stage to build vllm - this stage builds and installs # vllm, tensorizer and vllm-tgis-adapter and builds uv cache @@ -199,6 +215,7 @@ COPY --from=torch-builder /tmp/control /dev/null COPY --from=arrow-builder /tmp/control /dev/null COPY --from=cv-builder /tmp/control /dev/null COPY --from=numa-builder /tmp/control /dev/null +COPY --from=numba-builder /tmp/control /dev/null ARG VLLM_TARGET_DEVICE=cpu ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1 @@ -206,6 +223,8 @@ ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1 # this step installs vllm and populates uv cache # with all the transitive dependencies RUN --mount=type=cache,target=/root/.cache/uv \ + dnf install llvm15 llvm15-devel -y && \ + rpm -ivh --nodeps https://mirror.stream.centos.org/9-stream/CRB/ppc64le/os/Packages/protobuf-lite-devel-3.14.0-16.el9.ppc64le.rpm && \ source /opt/rh/gcc-toolset-13/enable && \ git clone https://github.com/huggingface/xet-core.git && cd xet-core/hf_xet/ && \ uv pip install maturin && \ @@ -215,15 +234,18 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \ --mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \ --mount=type=bind,from=numa-builder,source=/numactl/,target=/numactl/,rw \ + --mount=type=bind,from=numba-builder,source=/numbawheels/,target=/numbawheels/,ro \ --mount=type=bind,src=.,dst=/src/,rw \ source /opt/rh/gcc-toolset-13/enable && \ - uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \ + export PATH=$PATH:/usr/lib64/llvm15/bin && \ + uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /numbawheels/*.whl && \ sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \ - uv pip install pandas pythran pybind11 /hf_wheels/*.whl && \ + sed -i -e 's/.*sentencepiece.*//g' /src/pyproject.toml /src/requirements/*.txt && \ + uv pip install sentencepiece==0.2.0 pandas pythran nanobind pybind11 /hf_wheels/*.whl && \ make -C /numactl install && \ # sentencepiece.pc is in some pkgconfig inside uv 
cache export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \ - uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \ + nanobind_DIR=$(uv pip show nanobind | grep Location | sed 's/^Location: //;s/$/\/nanobind\/cmake/') && uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \ cd /src/ && \ uv build --wheel --out-dir /vllmwheel/ --no-build-isolation && \ uv pip install /vllmwheel/*.whl @@ -250,7 +272,7 @@ RUN git clone --recursive https://github.com/Reference-LAPACK/lapack.git -b v${L FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS vllm-openai ARG PYTHON_VERSION=3.12 -ARG OPENBLAS_VERSION=0.3.29 +ARG OPENBLAS_VERSION=0.3.30 # Set Environment Variables for venv & openblas ENV VIRTUAL_ENV=/opt/vllm @@ -268,6 +290,7 @@ COPY --from=vllmcache-builder /tmp/control /dev/null COPY --from=numa-builder /tmp/control /dev/null COPY --from=lapack-builder /tmp/control /dev/null COPY --from=openblas-builder /tmp/control /dev/null +COPY --from=numba-builder /tmp/control /dev/null # install gcc-11, python, openblas, numactl, lapack RUN --mount=type=cache,target=/root/.cache/uv \ @@ -276,13 +299,13 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=openblas-builder,source=/OpenBLAS-$OPENBLAS_VERSION/,target=/openblas/,rw \ rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \ microdnf install --nodocs -y \ - tar findutils openssl \ + libomp tar findutils openssl llvm15 llvm15-devel \ pkgconfig xsimd g++ gcc-fortran libsndfile \ libtiff libjpeg openjpeg2 zlib zeromq \ freetype lcms2 libwebp tcl tk utf8proc \ - harfbuzz fribidi libraqm libimagequant libxcb \ + harfbuzz fribidi libraqm libimagequant libxcb util-linux \ python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \ - && microdnf clean all \ + && export PATH=$PATH:/usr/lib64/llvm15/bin && microdnf clean all \ && python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \ && python -m pip install -U pip uv --no-cache \ && make -C /numactl install \ @@ -298,7 +321,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \ --mount=type=bind,from=vllmcache-builder,source=/hf_wheels/,target=/hf_wheels/,ro \ --mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \ - HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /hf_wheels/*.whl /vllmwheel/*.whl + --mount=type=bind,from=numba-builder,source=/numbawheels/,target=/numbawheels/,ro \ + export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && uv pip install sentencepiece==0.2.0 && \ + HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /numbawheels/*.whl /hf_wheels/*.whl /vllmwheel/*.whl + COPY ./ /workspace/vllm WORKDIR /workspace/vllm @@ -314,4 +340,4 @@ WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks -ENTRYPOINT ["vllm", "serve"] +ENTRYPOINT ["vllm", "serve"] \ No newline at end of file diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index c8900212e5a1..adb0879f20d4 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -12,7 +12,7 @@ ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}} RUN apt-get update -q -y && apt-get install -q -y 
\ sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \ apt-transport-https ca-certificates wget curl -# Remove sccache +# Remove sccache RUN python3 -m pip install --upgrade pip RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" ARG COMMON_WORKDIR diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 873c2fbcd4d3..5479eebaf795 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -1,13 +1,13 @@ ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete -ARG TRITON_BRANCH="f9e5bf54" +ARG TRITON_BRANCH="57c693b6" ARG TRITON_REPO="https://github.com/ROCm/triton.git" -ARG PYTORCH_BRANCH="b2fb6885" +ARG PYTORCH_BRANCH="1c57644d" ARG PYTORCH_VISION_BRANCH="v0.23.0" ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="2ab9f4cd" +ARG AITER_BRANCH="eef23c7f" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index ffc3abd38965..49ea39cad512 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -69,4 +69,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils + +# install nixl from source code +RUN python3 /workspace/vllm/tools/install_nixl_from_source_ubuntu.py +ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages/.nixl.mesonpy.libs/plugins/" + ENTRYPOINT ["vllm", "serve"] diff --git a/docs/api/README.md b/docs/api/README.md index 86e310f567dd..d3a141f32730 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -20,8 +20,6 @@ API documentation for vLLM's configuration classes. - [vllm.config.CompilationConfig][] - [vllm.config.VllmConfig][] -[](){ #offline-inference-api } - ## Offline Inference LLM Class. @@ -45,18 +43,14 @@ Engine classes for offline and online inference. Inference parameters for vLLM APIs. -[](){ #sampling-params } - - [vllm.SamplingParams][] - [vllm.PoolingParams][] -[](){ #multi-modality } - ## Multi-Modality vLLM provides experimental support for multi-modal models through the [vllm.multimodal][] package. -Multi-modal inputs can be passed alongside text and token prompts to [supported models][supported-mm-models] +Multi-modal inputs can be passed alongside text and token prompts to [supported models](../models/supported_models.md#list-of-multimodal-language-models) via the `multi_modal_data` field in [vllm.inputs.PromptType][]. Looking to add your own multi-modal model? Please follow the instructions listed [here](../contributing/model/multimodal.md). 
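To illustrate the `multi_modal_data` field mentioned in the docs/api/README.md hunk above, here is a minimal offline-inference sketch; the checkpoint, image path, and the `<image>` placeholder follow a LLaVA-style template and are illustrative assumptions, since every multimodal model defines its own prompt format:

```python
# Minimal sketch: pass an image alongside a text prompt via multi_modal_data.
# The model name, prompt template, and image path are assumptions for illustration.
from PIL import Image
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")
image = Image.open("example.jpg")  # any local RGB image

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is shown in this image? ASSISTANT:",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
```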
diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png index 0838bfa37fe6..f8c104ba1425 100644 Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ diff --git a/docs/community/sponsors.md b/docs/community/sponsors.md index 6ad3a6625266..8abb07caaab6 100644 --- a/docs/community/sponsors.md +++ b/docs/community/sponsors.md @@ -34,6 +34,7 @@ Compute Resources: - Trainy - UC Berkeley - UC San Diego +- Volcengine Slack Sponsor: Anyscale diff --git a/docs/configuration/README.md b/docs/configuration/README.md index 6a8fbc79f4af..85ae642ba6dd 100644 --- a/docs/configuration/README.md +++ b/docs/configuration/README.md @@ -4,6 +4,6 @@ This section lists the most common options for running vLLM. There are three main levels of configuration, from highest priority to lowest priority: -- [Request parameters][completions-api] and [input arguments][sampling-params] +- [Request parameters](../serving/openai_compatible_server.md#completions-api) and [input arguments](../api/README.md#inference-parameters) - [Engine arguments](./engine_args.md) - [Environment variables](./env_vars.md) diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index efda9c8e019e..5ce43c798405 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -11,8 +11,7 @@ The following code splits the model across 2 GPUs. ```python from vllm import LLM -llm = LLM(model="ibm-granite/granite-3.1-8b-instruct", - tensor_parallel_size=2) +llm = LLM(model="ibm-granite/granite-3.1-8b-instruct", tensor_parallel_size=2) ``` !!! warning @@ -24,7 +23,7 @@ llm = LLM(model="ibm-granite/granite-3.1-8b-instruct", !!! note With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). - You can convert the model checkpoint to a sharded checkpoint using . The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. + You can convert the model checkpoint to a sharded checkpoint using [examples/offline_inference/save_sharded_state.py](../../examples/offline_inference/save_sharded_state.py). The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. ## Quantization @@ -43,9 +42,7 @@ and the maximum batch size (`max_num_seqs` option). 
```python from vllm import LLM -llm = LLM(model="adept/fuyu-8b", - max_model_len=2048, - max_num_seqs=2) +llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2) ``` ## Reduce CUDA Graphs @@ -61,12 +58,12 @@ You can adjust `compilation_config` to achieve a better balance between inferenc ```python from vllm import LLM - from vllm.config import CompilationConfig, CompilationLevel + from vllm.config import CompilationConfig, CompilationMode llm = LLM( model="meta-llama/Llama-3.1-8B-Instruct", compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, # By default, it goes up to max_num_seqs cudagraph_capture_sizes=[1, 2, 4, 8, 16], ), @@ -78,8 +75,7 @@ You can disable graph capturing completely via the `enforce_eager` flag: ```python from vllm import LLM -llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", - enforce_eager=True) +llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enforce_eager=True) ``` ## Adjust cache size @@ -97,8 +93,10 @@ You can allow a smaller number of multi-modal items per prompt to reduce the mem from vllm import LLM # Accept up to 3 images and 1 video per prompt -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - limit_mm_per_prompt={"image": 3, "video": 1}) +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={"image": 3, "video": 1}, +) ``` You can go a step further and disable unused modalities completely by setting its limit to zero. @@ -108,8 +106,10 @@ For example, if your application only accepts image input, there is no need to a from vllm import LLM # Accept any number of images but no videos -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - limit_mm_per_prompt={"video": 0}) +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={"video": 0}, +) ``` You can even run a multi-modal model for text-only inference: @@ -118,10 +118,52 @@ You can even run a multi-modal model for text-only inference: from vllm import LLM # Don't accept images. Just text. -llm = LLM(model="google/gemma-3-27b-it", - limit_mm_per_prompt={"image": 0}) +llm = LLM( + model="google/gemma-3-27b-it", + limit_mm_per_prompt={"image": 0}, +) ``` +### Configurable options + +`limit_mm_per_prompt` also accepts configurable options per modality. In the configurable form, you still specify `count`, and you may optionally provide size hints that control how vLLM profiles and reserves memory for your multi‑modal inputs. This helps you tune memory for the actual media you expect, instead of the model’s absolute maxima. + +Configurable options by modality: + +- `image`: `{"count": int, "width": int, "height": int}` +- `video`: `{"count": int, "num_frames": int, "width": int, "height": int}` +- `audio`: `{"count": int, "length": int}` + +Details could be found in [`ImageDummyOptions`][vllm.config.multimodal.ImageDummyOptions], [`VideoDummyOptions`][vllm.config.multimodal.VideoDummyOptions], and [`AudioDummyOptions`][vllm.config.multimodal.AudioDummyOptions]. + +Examples: + +```python +from vllm import LLM + +# Up to 5 images per prompt, profile with 512x512. +# Up to 1 video per prompt, profile with 32 frames at 640x640. +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={ + "image": {"count": 5, "width": 512, "height": 512}, + "video": {"count": 1, "num_frames": 32, "width": 640, "height": 640}, + }, +) +``` + +For backward compatibility, passing an integer works as before and is interpreted as `{"count": }`. 
For example: + +- `limit_mm_per_prompt={"image": 5}` is equivalent to `limit_mm_per_prompt={"image": {"count": 5}}` +- You can mix formats: `limit_mm_per_prompt={"image": 5, "video": {"count": 1, "num_frames": 32, "width": 640, "height": 640}}` + +!!! note + - The size hints affect memory profiling only. They shape the dummy inputs used to compute reserved activation sizes. They do not change how inputs are actually processed at inference time. + - If a hint exceeds what the model can accept, vLLM clamps it to the model's effective maximum and may log a warning. + +!!! warning + These size hints currently only affect activation memory profiling. Encoder cache size is determined by the actual inputs at runtime and is not limited by these hints. + ## Multi-modal processor arguments For certain models, you can adjust the multi-modal processor arguments to @@ -133,14 +175,14 @@ Here are some examples: from vllm import LLM # Available for Qwen2-VL series models -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_processor_kwargs={ - "max_pixels": 768 * 768, # Default is 1280 * 28 * 28 - }) +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_kwargs={"max_pixels": 768 * 768}, # Default is 1280 * 28 * 28 +) # Available for InternVL series models -llm = LLM(model="OpenGVLab/InternVL2-2B", - mm_processor_kwargs={ - "max_dynamic_patch": 4, # Default is 12 - }) +llm = LLM( + model="OpenGVLab/InternVL2-2B", + mm_processor_kwargs={"max_dynamic_patch": 4}, # Default is 12 +) ``` diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 5c74610ebd29..b0d390d7e1cb 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -27,8 +27,6 @@ You can monitor the number of preemption requests through Prometheus metrics exp In vLLM V1, the default preemption mode is `RECOMPUTE` rather than `SWAP`, as recomputation has lower overhead in the V1 architecture. -[](){ #chunked-prefill } - ## Chunked Prefill Chunked prefill allows vLLM to process large prefills in smaller chunks and batch them together with decode requests. This feature helps improve both throughput and latency by better balancing compute-bound (prefill) and memory-bound (decode) operations. 
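As a rough illustration of the chunked prefill trade-off described above, the sketch below tunes the chunk budget; chunked prefill is typically enabled by default in V1, so these values are illustrative tuning points rather than required settings:

```python
# Hedged sketch: a smaller max_num_batched_tokens yields smaller prefill chunks,
# which tends to favor inter-token latency; larger values favor throughput.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    enable_chunked_prefill=True,   # usually already the default in V1
    max_num_batched_tokens=2048,   # per-step token budget shared by prefill chunks and decodes
    max_num_seqs=128,              # cap on concurrently scheduled requests
)
```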
@@ -100,7 +98,7 @@ from vllm import LLM llm = LLM( model="meta-llama/Llama-3.3-70B-Instruct, tensor_parallel_size=4, - pipeline_parallel_size=2 + pipeline_parallel_size=2, ) ``` @@ -174,14 +172,14 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u Known supported models (with corresponding benchmarks): -- dots_ocr () -- GLM-4.1V or above () -- InternVL () -- Kimi-VL () -- Llama4 () -- MiniCPM-V-2.5 or above (, ) -- Qwen2-VL or above (, , ) -- Step3 () +- dots_ocr () +- GLM-4.1V or above () +- InternVL () +- Kimi-VL () +- Llama4 () +- MiniCPM-V-2.5 or above (, ) +- Qwen2-VL or above (, , ) +- Step3 () ## Input Processing @@ -257,18 +255,24 @@ Examples: ```python # Use a larger cache -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_processor_cache_gb=8) +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_cache_gb=8, +) # Use a shared-memory based IPC cache -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - tensor_parallel_size=2, - mm_processor_cache_type="shm", - mm_processor_cache_gb=8) +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + tensor_parallel_size=2, + mm_processor_cache_type="shm", + mm_processor_cache_gb=8, +) # Disable the cache -llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", - mm_processor_cache_gb=0) +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + mm_processor_cache_gb=0, +) ``` ### Cache Placement diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md index e456077e0495..25d371e627b7 100644 --- a/docs/configuration/tpu.md +++ b/docs/configuration/tpu.md @@ -96,7 +96,7 @@ Although it’s common to do this with GPUs, don't try to fragment 2 or 8 differ ### Tune your workloads -Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](gh-file:benchmarks/auto_tune/README.md) to optimize your workloads for your use case. +Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](../../benchmarks/auto_tune/README.md) to optimize your workloads for your use case. ### Future Topics We'll Cover diff --git a/docs/contributing/README.md b/docs/contributing/README.md index b0a95b3b3d3a..368c0dc84b3a 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -22,7 +22,7 @@ Unsure on where to start? Check out the following links for tasks to work on: ## License -See . +See [LICENSE](../../LICENSE). ## Developing @@ -54,7 +54,7 @@ For more details about installing from source and installing for other hardware, For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations. !!! tip - vLLM is compatible with Python versions 3.9 to 3.12. However, vLLM's default [Dockerfile](gh-file:docker/Dockerfile) ships with Python 3.12 and tests in CI (except `mypy`) are run with Python 3.12. + vLLM is compatible with Python versions 3.10 to 3.13. However, vLLM's default [Dockerfile](../../docker/Dockerfile) ships with Python 3.12 and tests in CI (except `mypy`) are run with Python 3.12. Therefore, we recommend developing with Python 3.12 to minimise the chance of your local environment clashing with our CI environment. @@ -83,12 +83,12 @@ vLLM's `pre-commit` hooks will now run automatically every time you commit. 
```bash pre-commit run --hook-stage manual markdownlint - pre-commit run --hook-stage manual mypy-3.9 + pre-commit run --hook-stage manual mypy-3.10 ``` ### Documentation -MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file, . +MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file, [mkdocs.yaml](../../mkdocs.yaml). Get started with: @@ -152,7 +152,7 @@ pytest -s -v tests/test_logger.py If you encounter a bug or have a feature request, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible. !!! important - If you discover a security vulnerability, please follow the instructions [here](gh-file:SECURITY.md#reporting-a-vulnerability). + If you discover a security vulnerability, please follow the instructions [here](../../SECURITY.md). ## Pull Requests & Code Reviews @@ -162,7 +162,7 @@ code quality and improve the efficiency of the review process. ### DCO and Signed-off-by -When contributing changes to this project, you must agree to the . +When contributing changes to this project, you must agree to the [DCO](../../DCO). Commits must include a `Signed-off-by:` header which certifies agreement with the terms of the DCO. diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md index 6b1eabf3d67f..89524ed3bc63 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -6,9 +6,10 @@ toc_depth: 4 vLLM provides comprehensive benchmarking tools for performance testing and evaluation: -- **[Benchmark CLI]**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing -- **[Performance benchmarks][performance-benchmarks]**: Automated CI benchmarks for development -- **[Nightly benchmarks][nightly-benchmarks]**: Comparative benchmarks against alternatives +- **[Benchmark CLI](#benchmark-cli)**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing +- **[Parameter sweeps](#parameter-sweeps)**: Automate `vllm bench` runs for multiple configurations +- **[Performance benchmarks](#performance-benchmarks)**: Automated CI benchmarks for development +- **[Nightly benchmarks](#nightly-benchmarks)**: Comparative benchmarks against alternatives [Benchmark CLI]: #benchmark-cli @@ -29,12 +30,13 @@ th { | Dataset | Online | Offline | Data Path | |---------|--------|---------|-----------| | ShareGPT | ✅ | ✅ | `wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json` | -| ShareGPT4V (Image) | ✅ | ✅ | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json`
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
`wget http://images.cocodataset.org/zips/train2017.zip` | +| ShareGPT4V (Image) | ✅ | ✅ | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json`
Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
`wget http://images.cocodataset.org/zips/train2017.zip` | | ShareGPT4Video (Video) | ✅ | ✅ | `git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video` | | BurstGPT | ✅ | ✅ | `wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv` | | Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` | | Random | ✅ | ✅ | `synthetic` | | RandomMultiModal (Image/Video) | 🟡 | 🚧 | `synthetic` | +| RandomForReranking | ✅ | ✅ | `synthetic` | | Prefix Repetition | ✅ | ✅ | `synthetic` | | HuggingFace-VisionArena | ✅ | ✅ | `lmarena-ai/VisionArena-Chat` | | HuggingFace-MMVU | ✅ | ✅ | `yale-nlp/MMVU` | @@ -713,7 +715,7 @@ Generate synthetic image inputs alongside random text prompts to stress-test vis Notes: -- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. +- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. - Video sampling is not yet implemented. Start the server (example): @@ -821,7 +823,7 @@ you should set `--endpoint /v1/embeddings` to use the Embeddings API. The backen - CLIP: `--backend openai-embeddings-clip` - VLM2Vec: `--backend openai-embeddings-vlm2vec` -For other models, please add your own implementation inside to match the expected instruction format. +For other models, please add your own implementation inside [vllm/benchmarks/lib/endpoint_request_func.py](../../vllm/benchmarks/lib/endpoint_request_func.py) to match the expected instruction format. You can use any text or multi-modal dataset to benchmark the model, as long as the model supports it. For example, you can use ShareGPT and VisionArena to benchmark vision-language embeddings. @@ -878,7 +880,207 @@ vllm bench serve \ -[](){ #performance-benchmarks } +#### Reranker Benchmark + +Benchmark the performance of rerank requests in vLLM. + +
+Show more + +Unlike generative models which use Completions API or Chat Completions API, +you should set `--backend vllm-rerank` and `--endpoint /v1/rerank` to use the Reranker API. + +For reranking, the only supported dataset is `--dataset-name random-rerank`. + +Start the server: + +```bash +vllm serve BAAI/bge-reranker-v2-m3 +``` + +Run the benchmark: + +```bash +vllm bench serve \ + --model BAAI/bge-reranker-v2-m3 \ + --backend vllm-rerank \ + --endpoint /v1/rerank \ + --dataset-name random-rerank \ + --tokenizer BAAI/bge-reranker-v2-m3 \ + --random-input-len 512 \ + --num-prompts 10 \ + --random-batch-size 5 +``` + +For reranker models, this will create `num_prompts / random_batch_size` requests with +`random_batch_size` "documents" where each one has close to `random_input_len` tokens. +In the example above, this results in 2 rerank requests with 5 "documents" each where +each document has close to 512 tokens. + +Please note that the `/v1/rerank` endpoint is also supported by embedding models, so if you're +running with an embedding model, also set `--no_reranker`. In that case the server treats the +query as an individual prompt, so we send `random_batch_size - 1` documents to account for the +extra prompt that is the query, and the token accounting used to report throughput numbers is +adjusted accordingly. + +
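For reference, here is a minimal sketch of the kind of request this benchmark issues against `/v1/rerank`; the field names follow the rerank API served by vLLM, while the query and documents are placeholders:

```python
# Hedged sketch of a single rerank request, mirroring what the benchmark sends.
import requests

payload = {
    "model": "BAAI/bge-reranker-v2-m3",
    "query": "What is the capital of France?",    # one query per request
    "documents": [                                # random_batch_size documents
        "Paris is the capital of France.",
        "The Eiffel Tower is located in Paris.",
    ],
}
resp = requests.post("http://localhost:8000/v1/rerank", json=payload)
print(resp.json())  # per-document relevance scores
```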
+ +## Parameter Sweeps + +### Online Benchmark + +[`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) automatically starts `vllm serve` and runs `vllm bench serve` to evaluate vLLM over multiple configurations. + +Follow these steps to run the script: + +1. Construct the base command to `vllm serve`, and pass it to the `--serve-cmd` option. +2. Construct the base command to `vllm bench serve`, and pass it to the `--bench-cmd` option. +3. (Optional) If you would like to vary the settings of `vllm serve`, create a new JSON file and populate it with the parameter combinations you want to test. Pass the file path to `--serve-params`. + + - Example: Tuning `--max-num-seqs` and `--max-num-batched-tokens`: + + ```json + [ + { + "max_num_seqs": 32, + "max_num_batched_tokens": 1024 + }, + { + "max_num_seqs": 64, + "max_num_batched_tokens": 1024 + }, + { + "max_num_seqs": 64, + "max_num_batched_tokens": 2048 + }, + { + "max_num_seqs": 128, + "max_num_batched_tokens": 2048 + }, + { + "max_num_seqs": 128, + "max_num_batched_tokens": 4096 + }, + { + "max_num_seqs": 256, + "max_num_batched_tokens": 4096 + } + ] + ``` + +4. (Optional) If you would like to vary the settings of `vllm bench serve`, create a new JSON file and populate it with the parameter combinations you want to test. Pass the file path to `--bench-params`. + + - Example: Using different input/output lengths for random dataset: + + ```json + [ + { + "random_input_len": 128, + "random_output_len": 32 + }, + { + "random_input_len": 256, + "random_output_len": 64 + }, + { + "random_input_len": 512, + "random_output_len": 128 + } + ] + ``` + +5. Determine where you want to save the results, and pass that to `--output-dir`. + +Example command: + +```bash +python -m vllm.benchmarks.sweep.serve \ + --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \ + --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \ + --serve-params benchmarks/serve_hparams.json \ + --bench-params benchmarks/bench_hparams.json \ + -o benchmarks/results +``` + +!!! important + If both `--serve-params` and `--bench-params` are passed, the script will iterate over the Cartesian product between them. + You can use `--dry-run` to preview the commands to be run. + + We only start the server once for each `--serve-params`, and keep it running for multiple `--bench-params`. + Between each benchmark run, we call the `/reset_prefix_cache` and `/reset_mm_cache` endpoints to get a clean slate for the next run. + In case you are using a custom `--serve-cmd`, you can override the commands used for resetting the state by setting `--after-bench-cmd`. + +!!! note + By default, each parameter combination is run 3 times to make the results more reliable. You can adjust the number of runs by setting `--num-runs`. + +!!! tip + You can use the `--resume` option to continue the parameter sweep if one of the runs failed. + +### SLA Auto-Tuner + +[`vllm/benchmarks/sweep/serve_sla.py`](../../vllm/benchmarks/sweep/serve_sla.py) is a wrapper over [`vllm/benchmarks/sweep/serve.py`](../../vllm/benchmarks/sweep/serve.py) that tunes either the request rate or concurrency (choose using `--sla-variable`) in order to satisfy the SLA constraints given by `--sla-params`. 
+ +For example, to ensure E2E latency within different target values for 99% of requests: + +```json +[ + { + "p99_e2el_ms": "<=200" + }, + { + "p99_e2el_ms": "<=500" + }, + { + "p99_e2el_ms": "<=1000" + }, + { + "p99_e2el_ms": "<=2000" + } +] +``` + +Example command: + +```bash +python -m vllm.benchmarks.sweep.serve_sla \ + --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \ + --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \ + --serve-params benchmarks/serve_hparams.json \ + --bench-params benchmarks/bench_hparams.json \ + --sla-params benchmarks/sla_hparams.json \ + --sla-variable max_concurrency \ + -o benchmarks/results +``` + +The algorithm for adjusting the SLA variable is as follows: + +1. Run the benchmark with infinite QPS, and use the corresponding metrics to determine the initial value of the variable. + - For example, the initial request rate is set to the concurrency under infinite QPS. +2. If the SLA is still satisfied, keep doubling the value until the SLA is no longer satisfied. This gives a relatively narrow window that contains the point where the SLA is barely satisfied. +3. Apply binary search over the window to find the maximum value that still satisfies the SLA. + +!!! important + SLA tuning is applied over each combination of `--serve-params`, `--bench-params`, and `--sla-params`. + + For a given combination of `--serve-params` and `--bench-params`, we share the benchmark results across `--sla-params` to avoid rerunning benchmarks with the same SLA variable value. + +### Visualizer + +[`vllm/benchmarks/sweep/plot.py`](../../vllm/benchmarks/sweep/plot.py) can be used to plot performance curves from parameter sweep results. + +Example command: + +```bash +python -m vllm.benchmarks.sweep.plot benchmarks/results/ \ + --var-x max_concurrency \ + --row-by random_input_len \ + --col-by random_output_len \ + --curve-by api_server_count,max_num_batched_tokens \ + --filter-by 'max_concurrency<=1024' +``` + +!!! tip + You can use `--dry-run` to preview the figures to be plotted. ## Performance Benchmarks @@ -916,7 +1118,7 @@ For more results visualization, check the [visualizing the results](https://gith The latest performance results are hosted on the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm). -More information on the performance benchmarks and their parameters can be found in [Benchmark README](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md) and [performance benchmark description](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md). +More information on the performance benchmarks and their parameters can be found in [Benchmark README](https://github.com/intel-ai-tce/vllm/blob/more_cpu_models/.buildkite/nightly-benchmarks/README.md) and [performance benchmark description](../../.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md). ### Continuous Benchmarking @@ -942,12 +1144,10 @@ The benchmarking currently runs on a predefined set of models configured in the All continuous benchmarking results are automatically published to the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm). 
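Referring back to the SLA auto-tuner algorithm described above, the doubling-then-bisection search can be sketched as follows; `run_bench` and `meets_sla` are hypothetical stand-ins for launching `vllm bench serve` and checking the `--sla-params` constraints:

```python
# Hedged sketch of the SLA search: bracket by doubling, then binary search
# for the largest value of the SLA variable that still satisfies the SLA.
def tune_sla_variable(initial_value: int, run_bench, meets_sla) -> int:
    lo, hi = 0, max(1, initial_value)
    # Step 2: keep doubling until the SLA breaks, bracketing the boundary in (lo, hi].
    while meets_sla(run_bench(hi)):
        lo, hi = hi, hi * 2
    # Step 3: binary search inside the bracket for the boundary value.
    while lo + 1 < hi:
        mid = (lo + hi) // 2
        if meets_sla(run_bench(mid)):
            lo = mid
        else:
            hi = mid
    return lo  # 0 means even the initial value violated the SLA
```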
-[](){ #nightly-benchmarks } - ## Nightly Benchmarks These compare vLLM's performance against alternatives (`tgi`, `trt-llm`, and `lmdeploy`) when there are major updates of vLLM (e.g., bumping up to a new version). They are primarily intended for consumers to evaluate when to choose vLLM over other options and are triggered on every commit with both the `perf-benchmarks` and `nightly-benchmarks` labels. The latest nightly benchmark results are shared in major release blog posts such as [vLLM v0.6.0](https://blog.vllm.ai/2024/09/05/perf-update.html). -More information on the nightly benchmarks and their parameters can be found [here](gh-file:.buildkite/nightly-benchmarks/nightly-descriptions.md). +More information on the nightly benchmarks and their parameters can be found [here](../../.buildkite/nightly-benchmarks/nightly-descriptions.md). diff --git a/docs/contributing/ci/failures.md b/docs/contributing/ci/failures.md index d7e2dfbca876..dad04e75fbb6 100644 --- a/docs/contributing/ci/failures.md +++ b/docs/contributing/ci/failures.md @@ -64,7 +64,7 @@ Download the full log file from Buildkite locally. Strip timestamps and colorization: - +[.buildkite/scripts/ci-clean-log.sh](../../../.buildkite/scripts/ci-clean-log.sh) ```bash ./ci-clean-log.sh ci.log @@ -87,7 +87,7 @@ tail -525 ci_build.log | wl-copy CI test failures may be flaky. Use a bash loop to run repeatedly: - +[.buildkite/scripts/rerun-test.sh](../../../.buildkite/scripts/rerun-test.sh) ```bash ./rerun-test.sh tests/v1/engine/test_engine_core_client.py::test_kv_cache_events[True-tcp] diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index 3dae62dd5d94..f983c25f26ee 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -5,7 +5,7 @@ release in CI/CD. It is standard practice to submit a PR to update the PyTorch version as early as possible when a new [PyTorch stable release](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-cadence) becomes available. This process is non-trivial due to the gap between PyTorch -releases. Using as an example, this document outlines common steps to achieve this +releases. Using as an example, this document outlines common steps to achieve this update along with a list of potential issues and how to address them. ## Test PyTorch release candidates (RCs) @@ -85,9 +85,9 @@ and timeout. Additionally, since vLLM's fastcheck pipeline runs in read-only mod it doesn't populate the cache, so re-running it to warm up the cache is ineffective. -While ongoing efforts like [#17419](gh-issue:17419) +While ongoing efforts like address the long build time at its source, the current workaround is to set `VLLM_CI_BRANCH` -to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) +to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/long_build`) when manually triggering a build on Buildkite. This branch accomplishes two things: 1. Increase the timeout limit to 10 hours so that the build doesn't time out. @@ -100,35 +100,17 @@ to warm it up so that future builds are faster. ## Update dependencies -Several vLLM dependencies, such as FlashInfer, also depend on PyTorch and need +Several vLLM dependencies like xFormers depend on PyTorch and need to be updated accordingly. Rather than waiting for all of them to publish new releases (which would take too much time), they can be built from source to unblock the update process. 
-### FlashInfer - -Here is how to build and install it from source with `torch2.7.0+cu128` in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271): - -```bash -export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX' -export FLASHINFER_ENABLE_SM90=1 -uv pip install --system \ - --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1" -``` - -One caveat is that building FlashInfer from source adds approximately 30 -minutes to the vLLM build time. Therefore, it's preferable to cache the wheel in a -public location for immediate installation, such as [this FlashInfer wheel link](https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl). For future releases, contact the PyTorch release -team if you want to get the package published there. - ### xFormers -Similar to FlashInfer, here is how to build and install xFormers from source: - ```bash -export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX' +export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a' MAX_JOBS=16 uv pip install --system \ - --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30" + --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2" ``` ## Update all the different vLLM platforms @@ -138,5 +120,5 @@ to handle some platforms separately. The separation of requirements and Dockerfi for different platforms in vLLM CI/CD allows us to selectively choose which platforms to update. For instance, updating XPU requires the corresponding release from [Intel Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch) by Intel. -While updated vLLM to PyTorch 2.7.0 on CPU, CUDA, and ROCm, - completed the update for XPU. +While updated vLLM to PyTorch 2.7.0 on CPU, CUDA, and ROCm, + completed the update for XPU. diff --git a/docs/contributing/dockerfile/dockerfile.md b/docs/contributing/dockerfile/dockerfile.md index a7ff99aa26d5..14184b969366 100644 --- a/docs/contributing/dockerfile/dockerfile.md +++ b/docs/contributing/dockerfile/dockerfile.md @@ -1,6 +1,6 @@ # Dockerfile -We provide a to construct the image for running an OpenAI compatible server with vLLM. +We provide a [docker/Dockerfile](../../../docker/Dockerfile) to construct the image for running an OpenAI compatible server with vLLM. More information about deploying with Docker can be found [here](../../deployment/docker.md). Below is a visual representation of the multi-stage Dockerfile. The build graph contains the following nodes: diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index 36068bc14876..d8c40c519573 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -1,7 +1,7 @@ # Summary !!! important - Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve ` works first! + Many decoder language models can now be automatically loaded using the [Transformers backend](../../models/supported_models.md#transformers) without having to implement them in vLLM. See if `vllm serve ` works first! vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/README.md#compatibility-matrix) to optimize their performance. 
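As a quick way to exercise the Transformers-backend fallback mentioned in the model summary above, the offline equivalent of `vllm serve <model>` can be sketched as follows; the model name is a placeholder, and forcing the fallback assumes the `model_impl` engine argument is available in your vLLM version:

```python
# Hedged sketch: try loading a model through the Transformers backend before
# writing a native vLLM implementation. The model name is a placeholder.
from vllm import LLM

llm = LLM(model="your-org/your-decoder-model", model_impl="transformers")
print(llm.generate("Hello, world!")[0].outputs[0].text)
```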
diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md index aafdb1058e03..795bd5507a61 100644 --- a/docs/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -5,7 +5,7 @@ This guide walks you through the steps to implement a basic vLLM model. ## 1. Bring your model code First, clone the PyTorch model code from the source repository. -For instance, vLLM's [OPT model](gh-file:vllm/model_executor/models/opt.py) was adapted from +For instance, vLLM's [OPT model](../../../vllm/model_executor/models/opt.py) was adapted from HuggingFace's [modeling_opt.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py) file. !!! warning @@ -73,8 +73,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: ... ``` @@ -83,7 +83,7 @@ def forward( Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM. -For reference, check out our [Llama implementation](gh-file:vllm/model_executor/models/llama.py). vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out for more examples. +For reference, check out our [Llama implementation](../../../vllm/model_executor/models/llama.py). vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out [vllm/model_executor/models](../../../vllm/model_executor/models) for more examples. ## 3. (Optional) Implement tensor parallelism and quantization support @@ -130,22 +130,22 @@ We consider 3 different scenarios: 2. Models that combine Mamba layers (either Mamba-1 or Mamba-2) together with attention layers. 3. Models that combine Mamba-like mechanisms (e.g., Linear Attention, ShortConv) together with attention layers. -For case (1), we recommend looking at the implementation of [`MambaForCausalLM`](gh-file:vllm/model_executor/models/mamba.py) (for Mamba-1) or [`Mamba2ForCausalLM`](gh-file:vllm/model_executor/models/mamba2.py) (for Mamba-2) as a reference. +For case (1), we recommend looking at the implementation of [`MambaForCausalLM`](../../../vllm/model_executor/models/mamba.py) (for Mamba-1) or [`Mamba2ForCausalLM`](../../../vllm/model_executor/models/mamba2.py) (for Mamba-2) as a reference. The model should inherit protocol `IsAttentionFree` and also implement class methods `get_mamba_state_dtype_from_config` and `get_mamba_state_shape_from_config` to calculate the state shapes and data types from the config. -For the mamba layers themselves, please use the [`MambaMixer`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer.py) (for Mamba-1) or [`MambaMixer2`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer2.py) (for Mamba-2) classes. +For the mamba layers themselves, please use the [`MambaMixer`](../../../vllm/model_executor/layers/mamba/mamba_mixer.py) (for Mamba-1) or [`MambaMixer2`](../../../vllm/model_executor/layers/mamba/mamba_mixer2.py) (for Mamba-2) classes. 
Please *do not* use the `MambaCacheManager` (deprecated in V1) or replicate any of the V0-specific code paths in the existing model implementations. V0-only classes and code will be removed in the very near future. -The model should also be added to the `MODELS_CONFIG_MAP` dictionary in to ensure that the runtime defaults are optimized. +The model should also be added to the `MODELS_CONFIG_MAP` dictionary in [vllm/model_executor/models/config.py](../../../vllm/model_executor/models/config.py) to ensure that the runtime defaults are optimized. -For case (2), we recommend using as a reference the implementation of [`JambaForCausalLM`](gh-file:vllm/model_executor/models/jamba.py) (for an example of a model that uses Mamba-1 and attention together) or [`BambaForCausalLM`](gh-file:vllm/model_executor/models/bamba.py) (for an example of a model that uses Mamba-2 and attention together). +For case (2), we recommend using as a reference the implementation of [`JambaForCausalLM`](../../../vllm/model_executor/models/jamba.py) (for an example of a model that uses Mamba-1 and attention together) or [`BambaForCausalLM`](../../../vllm/model_executor/models/bamba.py) (for an example of a model that uses Mamba-2 and attention together). These models should follow the same instructions as case (1), but they should inherit protocol `IsHybrid` (instead of `IsAttentionFree`) and it is *not* necessary to add them to the `MODELS_CONFIG_MAP` (their runtime defaults will be inferred from the protocol). -For case (3), we recommend looking at the implementation of [`MiniMaxText01ForCausalLM`](gh-file:vllm/model_executor/models/minimax_text_01.py) or [`Lfm2ForCausalLM`](gh-file:vllm/model_executor/models/lfm2.py) as a reference, which use custom "mamba-like" layers `MiniMaxText01LinearAttention` and `ShortConv` respectively. +For case (3), we recommend looking at the implementation of [`MiniMaxText01ForCausalLM`](../../../vllm/model_executor/models/minimax_text_01.py) or [`Lfm2ForCausalLM`](../../../vllm/model_executor/models/lfm2.py) as a reference, which use custom "mamba-like" layers `MiniMaxText01LinearAttention` and `ShortConv` respectively. Please follow the same guidelines as case (2) for implementing these models. We use "mamba-like" to refer to layers that posses a state that is updated in-place, rather than being appended-to (like KV cache for attention). For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype`, `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`. It is also necessary to implement the "attention meta-data" class which handles the meta-data that is common across all layers. -Please see [`LinearAttentionMetadata`](gh-file:vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](gh-file:v1/attention/backends/short_conv_attn.py) for examples of this. +Please see [`LinearAttentionMetadata`](../../../vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](../../../vllm/v1/attention/backends/short_conv_attn.py) for examples of this. Finally, if one wants to support torch compile and CUDA graphs, it necessary to wrap the call to the mamba-like layer inside a custom op and register it. -Please see the calls to `direct_register_custom_op` in or for examples of this. -The new custom op should then be added to the list `_attention_ops` in to ensure that piecewise CUDA graphs works as intended. 
+Please see the calls to `direct_register_custom_op` in [vllm/model_executor/models/minimax_text_01.py](../../../vllm/model_executor/models/minimax_text_01.py) or [vllm/model_executor/layers/mamba/short_conv.py](../../../vllm/model_executor/layers/mamba/short_conv.py) for examples of this. +The new custom op should then be added to the list `_attention_ops` in [vllm/config/compilation.py](../../../vllm/config/compilation.py) to ensure that piecewise CUDA graphs works as intended. diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index 724dc2284e28..4e74afc688cf 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -16,7 +16,7 @@ Further update the model as follows: ... @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "" @@ -45,14 +45,14 @@ Further update the model as follows: ... def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor: - assert self.vision_encoder is not None image_features = self.vision_encoder(image_input) return self.multi_modal_projector(image_features) def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - + self, + **kwargs: object, + ) -> MultiModalEmbeddings | None: # Validate the multimodal input keyword arguments image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: @@ -110,7 +110,7 @@ to return the maximum number of input items for each modality supported by the m For example, if the model supports any number of images but only one video per prompt: ```python -def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: +def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": 1} ``` @@ -258,7 +258,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -421,8 +421,10 @@ Assuming that the memory usage increases with the number of tokens, the dummy in ```python def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() - return ImageSize(width=image_processor.size["width"], - height=image_processor.size["height"]) + return ImageSize( + width=image_processor.size["width"], + height=image_processor.size["height"], + ) ``` Fuyu does not expect image placeholders in the inputs to HF processor, so @@ -452,10 +454,12 @@ Assuming that the memory usage increases with the number of tokens, the dummy in return { "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } ``` @@ -503,7 +507,7 @@ return a schema of the tensors outputted by the HF processor that are related to ``` !!! note - Our [actual code](gh-file:vllm/model_executor/models/llava.py) additionally supports + Our [actual code](../../../vllm/model_executor/models/llava.py) additionally supports pre-computed image embeddings, which can be passed to be model via the `image_embeds` argument. 
=== "With postprocessing: Fuyu" @@ -565,7 +569,7 @@ return a schema of the tensors outputted by the HF processor that are related to ``` !!! note - Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling + Our [actual code](../../../vllm/model_executor/models/fuyu.py) has special handling for text-only inputs to prevent unnecessary warnings from HF processor. !!! note @@ -744,8 +748,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies image_width=image_size.width, image_height=image_size.height, ) - image_tokens = ([_IMAGE_TOKEN_ID] * ncols + - [_NEWLINE_TOKEN_ID]) * nrows + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows return PromptUpdateDetails.select_token_id( image_tokens + [bos_token_id], @@ -781,8 +784,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies image_width=image_size.width, image_height=image_size.height, ) - image_tokens = ([_IMAGE_TOKEN_ID] * ncols + - [_NEWLINE_TOKEN_ID]) * nrows + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows return PromptUpdateDetails.select_token_id( image_tokens + [bos_token_id], @@ -810,9 +812,11 @@ to register them to the multi-modal registry: from vllm.model_executor.models.interfaces import SupportsMultiModal + from vllm.multimodal import MULTIMODAL_REGISTRY -+ @MULTIMODAL_REGISTRY.register_processor(YourMultiModalProcessor, -+ info=YourProcessingInfo, -+ dummy_inputs=YourDummyInputsBuilder) ++ @MULTIMODAL_REGISTRY.register_processor( ++ YourMultiModalProcessor, ++ info=YourProcessingInfo, ++ dummy_inputs=YourDummyInputsBuilder, ++ ) class YourModelForImage2Seq(nn.Module, SupportsMultiModal): ``` @@ -824,8 +828,8 @@ Some HF processors directly insert feature tokens without replacing anything in Examples: -- BLIP-2 (insert at start of prompt): -- Molmo (insert after `<|endoftext|>` token): +- BLIP-2 (insert at start of prompt): [vllm/model_executor/models/blip2.py](../../../vllm/model_executor/models/blip2.py) +- Molmo (insert after `<|endoftext|>` token): [vllm/model_executor/models/molmo.py](../../../vllm/model_executor/models/molmo.py) ### Handling prompt updates unrelated to multi-modal data @@ -833,9 +837,9 @@ Examples: Examples: -- Chameleon (appends `sep_token`): -- Fuyu (appends `boa_token`): -- Molmo (applies chat template which is not defined elsewhere): +- Chameleon (appends `sep_token`): [vllm/model_executor/models/chameleon.py](../../../vllm/model_executor/models/chameleon.py) +- Fuyu (appends `boa_token`): [vllm/model_executor/models/fuyu.py](../../../vllm/model_executor/models/fuyu.py) +- Molmo (applies chat template which is not defined elsewhere): [vllm/model_executor/models/molmo.py](../../../vllm/model_executor/models/molmo.py) ### Custom HF processor @@ -843,6 +847,6 @@ Some models don't define an HF processor class on HF Hub. 
In that case, you can Examples: -- DeepSeek-VL2: -- InternVL: -- Qwen-VL: +- DeepSeek-VL2: [vllm/model_executor/models/deepseek_vl2.py](../../../vllm/model_executor/models/deepseek_vl2.py) +- InternVL: [vllm/model_executor/models/internvl.py](../../../vllm/model_executor/models/internvl.py) +- Qwen-VL: [vllm/model_executor/models/qwen_vl.py](../../../vllm/model_executor/models/qwen_vl.py) diff --git a/docs/contributing/model/registration.md b/docs/contributing/model/registration.md index 35f35ffa4cde..400d0f75caca 100644 --- a/docs/contributing/model/registration.md +++ b/docs/contributing/model/registration.md @@ -8,11 +8,11 @@ This page provides detailed instructions on how to do so. ## Built-in models -To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source][build-from-source]. +To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source](../../getting_started/installation/gpu.md#build-wheel-from-source). This gives you the ability to modify the codebase and test your model. -After you have implemented your model (see [tutorial](basic.md)), put it into the directory. -Then, add your model class to `_VLLM_MODELS` in so that it is automatically registered upon importing vLLM. +After you have implemented your model (see [tutorial](basic.md)), put it into the [vllm/model_executor/models](../../../vllm/model_executor/models) directory. +Then, add your model class to `_VLLM_MODELS` in [vllm/model_executor/models/registry.py](../../../vllm/model_executor/models/registry.py) so that it is automatically registered upon importing vLLM. Finally, update our [list of supported models](../../models/supported_models.md) to promote your model! !!! important @@ -42,7 +42,7 @@ def register(): ModelRegistry.register_model( "YourModelForCausalLM", - "your_code:YourModelForCausalLM" + "your_code:YourModelForCausalLM", ) ``` diff --git a/docs/contributing/model/tests.md b/docs/contributing/model/tests.md index 1206ad36771e..3ccd90cc66f7 100644 --- a/docs/contributing/model/tests.md +++ b/docs/contributing/model/tests.md @@ -9,7 +9,7 @@ Without them, the CI for your PR will fail. ### Model loading -Include an example HuggingFace repository for your model in . +Include an example HuggingFace repository for your model in [tests/models/registry.py](../../../tests/models/registry.py). This enables a unit test that loads dummy weights to ensure that the model can be initialized in vLLM. !!! important @@ -26,26 +26,24 @@ Passing these tests provides more confidence that your implementation is correct ### Model correctness -These tests compare the model outputs of vLLM against [HF Transformers](https://github.com/huggingface/transformers). You can add new tests under the subdirectories of . +These tests compare the model outputs of vLLM against [HF Transformers](https://github.com/huggingface/transformers). You can add new tests under the subdirectories of [tests/models](../../../tests/models). #### Generative models -For [generative models](../../models/generative_models.md), there are two levels of correctness tests, as defined in : +For [generative models](../../models/generative_models.md), there are two levels of correctness tests, as defined in [tests/models/utils.py](../../../tests/models/utils.py): - Exact correctness (`check_outputs_equal`): The text outputted by vLLM should exactly match the text outputted by HF. 
- Logprobs similarity (`check_logprobs_close`): The logprobs outputted by vLLM should be in the top-k logprobs outputted by HF, and vice versa. #### Pooling models -For [pooling models](../../models/pooling_models.md), we simply check the cosine similarity, as defined in . - -[](){ #mm-processing-tests } +For [pooling models](../../models/pooling_models.md), we simply check the cosine similarity, as defined in [tests/models/utils.py](../../../tests/models/utils.py). ### Multi-modal processing #### Common tests -Adding your model to verifies that the following input combinations result in the same outputs: +Adding your model to [tests/models/multimodal/processing/test_common.py](../../../tests/models/multimodal/processing/test_common.py) verifies that the following input combinations result in the same outputs: - Text + multi-modal data - Tokens + multi-modal data @@ -54,6 +52,6 @@ Adding your model to #### Model-specific tests -You can add a new file under to run tests that only apply to your model. +You can add a new file under [tests/models/multimodal/processing](../../../tests/models/multimodal/processing) to run tests that only apply to your model. -For example, if the HF processor for your model accepts user-specified keyword arguments, you can verify that the keyword arguments are being applied correctly, such as in . +For example, if the HF processor for your model accepts user-specified keyword arguments, you can verify that the keyword arguments are being applied correctly, such as in [tests/models/multimodal/processing/test_phi3v.py](../../../tests/models/multimodal/processing/test_phi3v.py). diff --git a/docs/contributing/model/transcription.md b/docs/contributing/model/transcription.md index 62e58e5c6ac5..a590ecd6a1a2 100644 --- a/docs/contributing/model/transcription.md +++ b/docs/contributing/model/transcription.md @@ -15,8 +15,9 @@ Declare supported languages and capabilities: - Set `supports_transcription_only=True` if the model should not serve text generation (eg Whisper). ??? code "supported_languages and supports_transcription_only" + ```python - from typing import ClassVar, Mapping, Optional, Literal + from typing import ClassVar, Mapping, Literal import numpy as np import torch from torch import nn @@ -43,6 +44,7 @@ Provide an ASR configuration via [get_speech_to_text_config][vllm.model_executor This is for controlling general behavior of the API when serving your model: ??? code "get_speech_to_text_config()" + ```python class YourASRModel(nn.Module, SupportsTranscription): ... @@ -71,6 +73,7 @@ Implement the prompt construction via [get_generation_prompt][vllm.model_executo Return a dict containing `multi_modal_data` with the audio, and either a `prompt` string or `prompt_token_ids`: ??? code "get_generation_prompt()" + ```python class YourASRModel(nn.Module, SupportsTranscription): ... @@ -81,10 +84,10 @@ Return a dict containing `multi_modal_data` with the audio, and either a `prompt audio: np.ndarray, stt_config: SpeechToTextConfig, model_config: ModelConfig, - language: Optional[str], + language: str | None, task_type: Literal["transcribe", "translate"], request_prompt: str, - to_language: Optional[str], + to_language: str | None, ) -> PromptType: # Example with a free-form instruction prompt task_word = "Transcribe" if task_type == "transcribe" else "Translate" @@ -107,6 +110,7 @@ Return a dict containing `multi_modal_data` with the audio, and either a `prompt Return a dict with separate `encoder_prompt` and `decoder_prompt` entries: ??? 
code "get_generation_prompt()" + ```python class YourASRModel(nn.Module, SupportsTranscription): ... @@ -117,10 +121,10 @@ Return a dict with separate `encoder_prompt` and `decoder_prompt` entries: audio: np.ndarray, stt_config: SpeechToTextConfig, model_config: ModelConfig, - language: Optional[str], + language: str | None, task_type: Literal["transcribe", "translate"], request_prompt: str, - to_language: Optional[str], + to_language: str | None, ) -> PromptType: if language is None: raise ValueError("Language must be specified") @@ -148,12 +152,16 @@ Language validation via [validate_language][vllm.model_executor.models.interface If your model requires a language and you want a default, override this method (see Whisper): ??? code "validate_language()" + ```python @classmethod - def validate_language(cls, language: Optional[str]) -> Optional[str]: + def validate_language(cls, language: str | None) -> str | None: if language is None: logger.warning( - "Defaulting to language='en'. If you wish to transcribe audio in a different language, pass the `language` field.") + "Defaulting to language='en'. If you wish to transcribe " + "audio in a different language, pass the `language` field " + "in the TranscriptionRequest." + ) language = "en" return super().validate_language(language) ``` @@ -165,6 +173,7 @@ Token accounting for streaming via [get_num_audio_tokens][vllm.model_executor.mo Provide a fast duration→token estimate to improve streaming usage statistics: ??? code "get_num_audio_tokens()" + ```python class YourASRModel(nn.Module, SupportsTranscription): ... @@ -175,7 +184,7 @@ Provide a fast duration→token estimate to improve streaming usage statistics: audio_duration_s: float, stt_config: SpeechToTextConfig, model_config: ModelConfig, - ) -> Optional[int]: + ) -> int | None: # Return None if unknown; otherwise return an estimate. return int(audio_duration_s * stt_config.sample_rate // 320) # example ``` @@ -191,6 +200,7 @@ The API server takes care of basic audio I/O and optional chunking before buildi Relevant server logic: ??? code "_preprocess_speech_to_text()" + ```python # vllm/entrypoints/openai/speech_to_text.py async def _preprocess_speech_to_text(...): @@ -238,9 +248,9 @@ No extra registration is required beyond having your model class available via t ## Examples in-tree -- Whisper encoder–decoder (audio-only): -- Voxtral decoder-only (audio embeddings + LLM): -- Gemma3n decoder-only with fixed instruction prompt: +- Whisper encoder–decoder (audio-only): [vllm/model_executor/models/whisper.py](../../../vllm/model_executor/models/whisper.py) +- Voxtral decoder-only (audio embeddings + LLM): [vllm/model_executor/models/voxtral.py](../../../vllm/model_executor/models/voxtral.py) +- Gemma3n decoder-only with fixed instruction prompt: [vllm/model_executor/models/gemma3n_mm.py](../../../vllm/model_executor/models/gemma3n_mm.py) ## Test with the API @@ -268,7 +278,7 @@ Once your model implements `SupportsTranscription`, you can test the endpoints ( http://localhost:8000/v1/audio/translations ``` -Or check out more examples in . +Or check out more examples in [examples/online_serving](../../../examples/online_serving). !!! note - If your model handles chunking internally (e.g., via its processor or encoder), set `min_energy_split_window_size=None` in the returned `SpeechToTextConfig` to disable server-side chunking. 
diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index f6a73e99546e..fed286f4b634 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -33,7 +33,7 @@ Traces can be visualized using . #### Offline Inference -Refer to for an example. +Refer to [examples/offline_inference/simple_profiling.py](../../examples/offline_inference/simple_profiling.py) for an example. #### OpenAI Server @@ -180,9 +180,13 @@ The profiling traces generated by the continuous profiling workflow are publicly The Python standard library includes [cProfile](https://docs.python.org/3/library/profile.html) for profiling Python code. vLLM includes a couple of helpers that make it easy to apply it to a section of vLLM. -Both the `vllm.utils.cprofile` and `vllm.utils.cprofile_context` functions can be +Both the `vllm.utils.profiling.cprofile` and `vllm.utils.profiling.cprofile_context` functions can be used to profile a section of code. +!!! note + The legacy import paths `vllm.utils.cprofile` and `vllm.utils.cprofile_context` are deprecated. + Please use `vllm.utils.profiling.cprofile` and `vllm.utils.profiling.cprofile_context` instead. + ### Example usage - decorator The first helper is a Python decorator that can be used to profile a function. @@ -190,9 +194,9 @@ If a filename is specified, the profile will be saved to that file. If no filena specified, profile data will be printed to stdout. ```python -import vllm.utils +from vllm.utils.profiling import cprofile -@vllm.utils.cprofile("expensive_function.prof") +@cprofile("expensive_function.prof") def expensive_function(): # some expensive code pass @@ -204,13 +208,13 @@ The second helper is a context manager that can be used to profile a block of code. Similar to the decorator, the filename is optional. ```python -import vllm.utils +from vllm.utils.profiling import cprofile_context def another_function(): # more expensive code pass -with vllm.utils.cprofile_context("another_function.prof"): +with cprofile_context("another_function.prof"): another_function() ``` diff --git a/docs/deployment/docker.md b/docs/deployment/docker.md index 1f19f2fecfab..d07358b85a5e 100644 --- a/docs/deployment/docker.md +++ b/docs/deployment/docker.md @@ -1,7 +1,5 @@ # Using Docker -[](){ #deployment-docker-pre-built-image } - ## Use vLLM's Official Docker Image vLLM offers an official Docker image for deployment. @@ -10,7 +8,7 @@ The image can be used to run OpenAI compatible server and is available on Docker ```bash docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ - --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \ + --env "HF_TOKEN=$HF_TOKEN" \ -p 8000:8000 \ --ipc=host \ vllm/vllm-openai:latest \ @@ -22,7 +20,7 @@ This image can also be used with other container engines such as [Podman](https: ```bash podman run --device nvidia.com/gpu=all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ - --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \ + --env "HF_TOKEN=$HF_TOKEN" \ -p 8000:8000 \ --ipc=host \ docker.io/vllm/vllm-openai:latest \ @@ -37,7 +35,7 @@ You can add any other [engine-args](../configuration/engine_args.md) you need af memory to share data between processes under the hood, particularly for tensor parallel inference. !!! note - Optional dependencies are not included in order to avoid licensing issues (e.g. ). + Optional dependencies are not included in order to avoid licensing issues (e.g. ). 
If you need to use those dependencies (having accepted the license terms), create a custom Dockerfile on top of the base image with an extra layer that installs them: @@ -62,11 +60,9 @@ You can add any other [engine-args](../configuration/engine_args.md) you need af RUN uv pip install --system git+https://github.com/huggingface/transformers.git ``` -[](){ #deployment-docker-build-image-from-source } - ## Building vLLM's Docker Image from Source -You can build and run vLLM from source via the provided . To build vLLM: +You can build and run vLLM from source via the provided [docker/Dockerfile](../../docker/Dockerfile). To build vLLM: ```bash # optionally specifies: --build-arg max_jobs=8 --build-arg nvcc_threads=2 @@ -128,7 +124,7 @@ To run vLLM with the custom-built Docker image: docker run --runtime nvidia --gpus all \ -v ~/.cache/huggingface:/root/.cache/huggingface \ -p 8000:8000 \ - --env "HUGGING_FACE_HUB_TOKEN=" \ + --env "HF_TOKEN=" \ vllm/vllm-openai ``` diff --git a/docs/deployment/frameworks/anyscale.md b/docs/deployment/frameworks/anyscale.md index 9957c5b14134..965742ec0726 100644 --- a/docs/deployment/frameworks/anyscale.md +++ b/docs/deployment/frameworks/anyscale.md @@ -1,11 +1,9 @@ # Anyscale -[](){ #deployment-anyscale } - [Anyscale](https://www.anyscale.com) is a managed, multi-cloud platform developed by the creators of Ray. Anyscale automates the entire lifecycle of Ray clusters in your AWS, GCP, or Azure account, delivering the flexibility of open-source Ray -without the operational overhead of maintaining Kubernetes control planes, configuring autoscalers, managing observability stacks, or manually managing head and worker nodes with helper scripts like . +without the operational overhead of maintaining Kubernetes control planes, configuring autoscalers, managing observability stacks, or manually managing head and worker nodes with helper scripts like [examples/online_serving/run_cluster.sh](../../../examples/online_serving/run_cluster.sh). When serving large language models with vLLM, Anyscale can rapidly provision [production-ready HTTPS endpoints](https://docs.anyscale.com/examples/deploy-ray-serve-llms) or [fault-tolerant batch inference jobs](https://docs.anyscale.com/examples/ray-data-llm). diff --git a/docs/deployment/frameworks/cerebrium.md b/docs/deployment/frameworks/cerebrium.md index 1f233c3204a1..960347d9525c 100644 --- a/docs/deployment/frameworks/cerebrium.md +++ b/docs/deployment/frameworks/cerebrium.md @@ -63,7 +63,7 @@ If successful, you should be returned a CURL command that you can call inference ??? console "Command" - ```python + ```bash curl -X POST https://api.cortex.cerebrium.ai/v4/p-xxxxxx/vllm/run \ -H 'Content-Type: application/json' \ -H 'Authorization: ' \ @@ -81,7 +81,7 @@ You should get a response like: ??? 
console "Response" - ```python + ```json { "run_id": "52911756-3066-9ae8-bcc9-d9129d1bd262", "result": { diff --git a/docs/deployment/frameworks/dstack.md b/docs/deployment/frameworks/dstack.md index fe4d87f78f2a..9d2c7f5bb565 100644 --- a/docs/deployment/frameworks/dstack.md +++ b/docs/deployment/frameworks/dstack.md @@ -83,7 +83,7 @@ After the provisioning, you can interact with the model by using the OpenAI SDK: client = OpenAI( base_url="https://gateway.", - api_key="" + api_key="", ) completion = client.chat.completions.create( @@ -93,7 +93,7 @@ After the provisioning, you can interact with the model by using the OpenAI SDK: "role": "user", "content": "Compose a poem that explains the concept of recursion in programming.", } - ] + ], ) print(completion.choices[0].message.content) diff --git a/docs/deployment/frameworks/haystack.md b/docs/deployment/frameworks/haystack.md index 836305cf15c4..b53b829d6d3c 100644 --- a/docs/deployment/frameworks/haystack.md +++ b/docs/deployment/frameworks/haystack.md @@ -34,7 +34,7 @@ pip install vllm haystack-ai api_key=Secret.from_token("VLLM-PLACEHOLDER-API-KEY"), model="mistralai/Mistral-7B-Instruct-v0.1", api_base_url="http://{your-vLLM-host-ip}:{your-vLLM-host-port}/v1", - generation_kwargs = {"max_tokens": 512} + generation_kwargs={"max_tokens": 512}, ) response = generator.run( diff --git a/docs/deployment/frameworks/hf_inference_endpoints.md b/docs/deployment/frameworks/hf_inference_endpoints.md index 75a234bdf142..d39bb9a899c8 100644 --- a/docs/deployment/frameworks/hf_inference_endpoints.md +++ b/docs/deployment/frameworks/hf_inference_endpoints.md @@ -32,28 +32,28 @@ This is the easiest way to get started with vLLM on Hugging Face Inference Endpo import os client = OpenAI( - base_url = DEPLOYMENT_URL, - api_key = os.environ["HF_TOKEN"] # https://huggingface.co/settings/tokens + base_url=DEPLOYMENT_URL, + api_key=os.environ["HF_TOKEN"], # https://huggingface.co/settings/tokens ) chat_completion = client.chat.completions.create( - model = "HuggingFaceTB/SmolLM3-3B", - messages = [ + model="HuggingFaceTB/SmolLM3-3B", + messages=[ { "role": "user", "content": [ { "type": "text", - "text": "Give me a brief explanation of gravity in simple terms." + "text": "Give me a brief explanation of gravity in simple terms.", } - ] + ], } ], - stream = True + stream=True, ) for message in chat_completion: - print(message.choices[0].delta.content, end = "") + print(message.choices[0].delta.content, end="") ``` !!! note @@ -86,34 +86,34 @@ This method applies to models with the [`transformers` library tag](https://hugg import os client = OpenAI( - base_url = DEPLOYMENT_URL, - api_key = os.environ["HF_TOKEN"] # https://huggingface.co/settings/tokens + base_url=DEPLOYMENT_URL, + api_key=os.environ["HF_TOKEN"], # https://huggingface.co/settings/tokens ) chat_completion = client.chat.completions.create( - model = "ibm-granite/granite-docling-258M", - messages = [ + model="ibm-granite/granite-docling-258M", + messages=[ { "role": "user", "content": [ { "type": "image_url", "image_url": { - "url": "https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png" - } + "url": "https://huggingface.co/ibm-granite/granite-docling-258M/resolve/main/assets/new_arxiv.png", + }, }, { "type": "text", - "text": "Convert this page to docling." 
- } + "text": "Convert this page to docling.", + }, ] } ], - stream = True + stream=True, ) for message in chat_completion: - print(message.choices[0].delta.content, end = "") + print(message.choices[0].delta.content, end="") ``` !!! note diff --git a/docs/deployment/frameworks/litellm.md b/docs/deployment/frameworks/litellm.md index 0d6c3729911a..9ea7c0373d2a 100644 --- a/docs/deployment/frameworks/litellm.md +++ b/docs/deployment/frameworks/litellm.md @@ -36,15 +36,16 @@ pip install vllm litellm ```python import litellm - messages = [{ "content": "Hello, how are you?","role": "user"}] + messages = [{"content": "Hello, how are you?", "role": "user"}] # hosted_vllm is prefix key word and necessary response = litellm.completion( - model="hosted_vllm/qwen/Qwen1.5-0.5B-Chat", # pass the vllm model name - messages=messages, - api_base="http://{your-vllm-server-host}:{your-vllm-server-port}/v1", - temperature=0.2, - max_tokens=80) + model="hosted_vllm/qwen/Qwen1.5-0.5B-Chat", # pass the vllm model name + messages=messages, + api_base="http://{your-vllm-server-host}:{your-vllm-server-port}/v1", + temperature=0.2, + max_tokens=80, + ) print(response) ``` diff --git a/docs/deployment/frameworks/lws.md b/docs/deployment/frameworks/lws.md index 3b9fa3ea43d6..14710a8dc333 100644 --- a/docs/deployment/frameworks/lws.md +++ b/docs/deployment/frameworks/lws.md @@ -35,7 +35,7 @@ Deploy the following yaml file `lws.yaml` - name: vllm-leader image: docker.io/vllm/vllm-openai:latest env: - - name: HUGGING_FACE_HUB_TOKEN + - name: HF_TOKEN value: command: - sh @@ -83,7 +83,7 @@ Deploy the following yaml file `lws.yaml` ephemeral-storage: 800Gi cpu: 125 env: - - name: HUGGING_FACE_HUB_TOKEN + - name: HF_TOKEN value: volumeMounts: - mountPath: /dev/shm diff --git a/docs/deployment/frameworks/retrieval_augmented_generation.md b/docs/deployment/frameworks/retrieval_augmented_generation.md index d86ab1600f12..8a5d18807d06 100644 --- a/docs/deployment/frameworks/retrieval_augmented_generation.md +++ b/docs/deployment/frameworks/retrieval_augmented_generation.md @@ -36,11 +36,11 @@ pip install -U vllm \ vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 ``` -1. Use the script: +1. Use the script: [examples/online_serving/retrieval_augmented_generation_with_langchain.py](../../../examples/online_serving/retrieval_augmented_generation_with_langchain.py) 1. Run the script - ```python + ```bash python retrieval_augmented_generation_with_langchain.py ``` @@ -74,10 +74,10 @@ pip install vllm \ vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 ``` -1. Use the script: +1. Use the script: [examples/online_serving/retrieval_augmented_generation_with_llamaindex.py](../../../examples/online_serving/retrieval_augmented_generation_with_llamaindex.py) 1. Run the script: - ```python + ```bash python retrieval_augmented_generation_with_llamaindex.py ``` diff --git a/docs/deployment/frameworks/streamlit.md b/docs/deployment/frameworks/streamlit.md index c119878f137a..1b214e1a32aa 100644 --- a/docs/deployment/frameworks/streamlit.md +++ b/docs/deployment/frameworks/streamlit.md @@ -20,7 +20,7 @@ pip install vllm streamlit openai vllm serve Qwen/Qwen1.5-0.5B-Chat ``` -1. Use the script: +1. Use the script: [examples/online_serving/streamlit_openai_chatbot_webserver.py](../../../examples/online_serving/streamlit_openai_chatbot_webserver.py) 1. 
Start the streamlit web UI and start to chat: diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index d3fda7eb6fb6..54031ec368b5 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -82,7 +82,7 @@ Next, start the vLLM server as a Kubernetes Deployment and Service: "vllm serve meta-llama/Llama-3.2-1B-Instruct" ] env: - - name: HUGGING_FACE_HUB_TOKEN + - name: HF_TOKEN valueFrom: secretKeyRef: name: hf-token-secret @@ -209,7 +209,7 @@ INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) "vllm serve mistralai/Mistral-7B-Instruct-v0.3 --trust-remote-code --enable-chunked-prefill --max_num_batched_tokens 1024" ] env: - - name: HUGGING_FACE_HUB_TOKEN + - name: HF_TOKEN valueFrom: secretKeyRef: name: hf-token-secret @@ -298,7 +298,7 @@ INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) "vllm serve mistralai/Mistral-7B-v0.3 --port 8000 --trust-remote-code --enable-chunked-prefill --max_num_batched_tokens 1024" ] env: - - name: HUGGING_FACE_HUB_TOKEN + - name: HF_TOKEN valueFrom: secretKeyRef: name: hf-token-secret diff --git a/docs/deployment/nginx.md b/docs/deployment/nginx.md index b3178e77f845..034068cddac3 100644 --- a/docs/deployment/nginx.md +++ b/docs/deployment/nginx.md @@ -2,8 +2,6 @@ This document shows how to launch multiple vLLM serving containers and use Nginx to act as a load balancer between the servers. -[](){ #nginxloadbalancer-nginx-build } - ## Build Nginx Container This guide assumes that you have just cloned the vLLM project and you're currently in the vllm root directory. @@ -27,8 +25,6 @@ Build the container: docker build . -f Dockerfile.nginx --tag nginx-lb ``` -[](){ #nginxloadbalancer-nginx-conf } - ## Create Simple Nginx Config file Create a file named `nginx_conf/nginx.conf`. Note that you can add as many servers as you'd like. In the below example we'll start with two. To add more, add another `server vllmN:8000 max_fails=3 fail_timeout=10000s;` entry to `upstream backend`. @@ -53,8 +49,6 @@ Create a file named `nginx_conf/nginx.conf`. Note that you can add as many serve } ``` -[](){ #nginxloadbalancer-nginx-vllm-container } - ## Build vLLM Container ```bash @@ -73,16 +67,12 @@ docker build \ --build-arg https_proxy=$https_proxy ``` -[](){ #nginxloadbalancer-nginx-docker-network } - ## Create Docker Network ```bash docker network create vllm_nginx ``` -[](){ #nginxloadbalancer-nginx-launch-container } - ## Launch vLLM Containers Notes: @@ -122,8 +112,6 @@ Notes: !!! note If you are behind proxy, you can pass the proxy settings to the docker run command via `-e http_proxy=$http_proxy -e https_proxy=$https_proxy`. -[](){ #nginxloadbalancer-nginx-launch-nginx } - ## Launch Nginx ```bash @@ -135,8 +123,6 @@ docker run \ --name nginx-lb nginx-lb:latest ``` -[](){ #nginxloadbalancer-nginx-verify-nginx } - ## Verify That vLLM Servers Are Ready ```bash diff --git a/docs/design/arch_overview.md b/docs/design/arch_overview.md index f1300a73c26c..b67b084a851a 100644 --- a/docs/design/arch_overview.md +++ b/docs/design/arch_overview.md @@ -47,9 +47,9 @@ Here is a sample of `LLM` class usage: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -More API details can be found in the [Offline Inference](#offline-inference-api) section of the API docs. +More API details can be found in the [Offline Inference](../api/README.md#offline-inference) section of the API docs. -The code for the `LLM` class can be found in . 
+The code for the `LLM` class can be found in [vllm/entrypoints/llm.py](../../vllm/entrypoints/llm.py). ### OpenAI-Compatible API Server @@ -60,7 +60,7 @@ This server can be started using the `vllm serve` command. vllm serve ``` -The code for the `vllm` CLI can be found in . +The code for the `vllm` CLI can be found in [vllm/entrypoints/cli/main.py](../../vllm/entrypoints/cli/main.py). Sometimes you may see the API server entrypoint used directly instead of via the `vllm` CLI command. For example: @@ -74,7 +74,7 @@ python -m vllm.entrypoints.openai.api_server --model `python -m vllm.entrypoints.openai.api_server` is deprecated and may become unsupported in a future release. -That code can be found in . +That code can be found in [vllm/entrypoints/openai/api_server.py](../../vllm/entrypoints/openai/api_server.py). More details on the API server can be found in the [OpenAI-Compatible Server](../serving/openai_compatible_server.md) document. @@ -101,7 +101,7 @@ processing. - **Output Processing**: Processes the outputs generated by the model, decoding the token IDs from a language model into human-readable text. -The code for `LLMEngine` can be found in . +The code for `LLMEngine` can be found in [vllm/engine/llm_engine.py](../../vllm/engine/llm_engine.py). ### AsyncLLMEngine @@ -111,9 +111,9 @@ incoming requests. The `AsyncLLMEngine` is designed for online serving, where it can handle multiple concurrent requests and stream outputs to clients. The OpenAI-compatible API server uses the `AsyncLLMEngine`. There is also a demo -API server that serves as a simpler example in . +API server that serves as a simpler example in [vllm/entrypoints/api_server.py](../../vllm/entrypoints/api_server.py). -The code for `AsyncLLMEngine` can be found in . +The code for `AsyncLLMEngine` can be found in [vllm/engine/async_llm_engine.py](../../vllm/engine/async_llm_engine.py). ## Worker diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index f88a29f6eadd..b56cf61e782c 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -17,7 +17,7 @@ In this document we will discuss the: In this document, we refer to pure decode (`max_query_len=1`) or speculative decode (`max_query_len =1+num_spec_tokens`) as **uniform decode** batches, and the opposite would be **non-uniform** batches (i.e., prefill or mixed prefill-decode batches). !!! note - The following contents are mostly based on the last commit of . + The following contents are mostly based on the last commit of . ## Motivation @@ -92,7 +92,7 @@ where `num_tokens` can be the padded token length, and `uniform_decode` is deter The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode. !!! note - The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs). 
+ The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs). ### `CudagraphDispatcher` @@ -106,9 +106,11 @@ The dispatch code looks like: batch_descriptor=BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=...) runtime_mode, batch_descriptor = cudagraphdispatcher.dispatch(batch_descriptor) # execution -with set_forward_context(..., - cudagraph_runtime_mode=runtime_mode, - batch_descriptor=batch_descriptor): +with set_forward_context( + ..., + cudagraph_runtime_mode=runtime_mode, + batch_descriptor=batch_descriptor, +): output = self.model(...) ``` @@ -165,7 +167,7 @@ class AttentionCGSupport(enum.Enum): """NO CUDA Graphs support""" ``` -Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation level. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture]. +Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code for [this][vllm.v1.worker.gpu_model_runner.GPUModelRunner._check_and_update_cudagraph_mode]. The following table lists backends that support full CUDA Graphs at the time of writing. 
@@ -200,12 +202,12 @@ os.environ.setdefault("VLLM_LOGGING_LEVEL", "DEBUG") import vllm from vllm.config import CUDAGraphMode -compilation_config = {"level": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"} +compilation_config = {"mode": 3, "cudagraph_mode": "FULL_AND_PIECEWISE"} model = vllm.LLM( - model="meta-llama/Llama-3.1-8B-Instruct", - dtype='auto', - compilation_config = compilation_config, - ) + model="meta-llama/Llama-3.1-8B-Instruct", + dtype="auto", + compilation_config=compilation_config, +) sampling_params = vllm.SamplingParams( temperature=0, # greedy decoding max_tokens=1024, diff --git a/docs/design/dbo.md b/docs/design/dbo.md index d92c47c80f95..f2d98ccd063f 100644 --- a/docs/design/dbo.md +++ b/docs/design/dbo.md @@ -34,10 +34,10 @@ To enable the DBO system pass in the `--enable-dbo` argument to your vllm serve * `--dbo-decode-token-threshold` the minimum number of tokens in a decode-only batch required to enable DBO for that batch * `--dbo-prefill-token-threshold` the minimum number of tokens in a batch containing at least one prefill required to enable DBO for that batch -Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `VLLM_ALL2ALL_BACKEND` environment variable must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests. +Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `--all2all-backend` argument must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests. Below is a command that will spin up a two DP rank server with expert parallelism and DBO enabled. -EX: `VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo` +EX: `vllm serve deepseek-ai/DeepSeek-V2-Lite --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo --all2all-backend deepep_low_latency` Note that there must be at least two GPUs visible in `CUDA_VISIBLE_DEVICES` diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index ee5701989265..76df0d8d8a38 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -2,7 +2,7 @@ ## Introduction -FusedMoEModularKernel is implemented [here](gh-file:/vllm/model_executor/layers/fused_moe/modular_kernel.py) +FusedMoEModularKernel is implemented [here](../..//vllm/model_executor/layers/fused_moe/modular_kernel.py) Based on the format of the input activations, FusedMoE implementations are broadly classified into 2 types. @@ -44,7 +44,7 @@ FusedMoEModularKernel splits the FusedMoE operation into 3 parts, The TopK Weight Application and Reduction components happen right after the Unpermute operation and before the All2All Combine. Note that the `FusedMoEPermuteExpertsUnpermute` is responsible for the Unpermute and `FusedMoEPrepareAndFinalize` is responsible for the All2All Combine. There is value in doing the TopK Weight Application and Reduction in the `FusedMoEPermuteExpertsUnpermute`. But some implementations choose to do it `FusedMoEPrepareAndFinalize`. In order to enable this flexibility, we have a TopKWeightAndReduce abstract class. -Please find the implementations of TopKWeightAndReduce [here](gh-file:vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py). 
+Please find the implementations of TopKWeightAndReduce [here](../../vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py). `FusedMoEPrepareAndFinalize::finalize()` method accepts a `TopKWeightAndReduce` argument that is invoked inside the method. The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExpertsUnpermute` and `FusedMoEPerpareAndFinalize` implementations to determine where the TopK Weight Application and Reduction happens. @@ -138,7 +138,7 @@ Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & C #### Step 1: Add an All2All manager -The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](gh-file:vllm/distributed/device_communicators/all2all.py). +The purpose of the All2All Manager is to set up the All2All kernel implementations. The `FusedMoEPrepareAndFinalize` implementations typically fetch a kernel-implementation "handle" from the All2All Manager to invoke the Dispatch and Combine functions. Please look at the All2All Manager implementations [here](../../vllm/distributed/device_communicators/all2all.py). #### Step 2: Add a FusedMoEPrepareAndFinalize Type @@ -213,29 +213,29 @@ Please take a look at [init_prepare_finalize](https://github.com/vllm-project/vl ### How To Unit Test -We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py). +We have `FusedMoEModularKernel` unit tests at [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py). The unit test iterates through all combinations of `FusedMoEPrepareAndFinalize` and `FusedMoEPremuteExpertsUnpermute` types and if they are compatible, runs some correctness tests. If you are adding some `FusedMoEPrepareAndFinalize` / `FusedMoEPermuteExpertsUnpermute` implementations, -1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](gh-file:tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively. +1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively. 2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`, `Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`, -`Config::is_fe_supports_chunking()` methods in [/tests/kernels/moe/modular_kernel_tools/common.py](gh-file:tests/kernels/moe/modular_kernel_tools/common.py) +`Config::is_fe_supports_chunking()` methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py) Doing this will add the new implementation to the test suite. ### How To Check `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` Compatibility -The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. +The unit test file [test_modular_kernel_combinations.py](../../tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. 
Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked with incompatible types, the script will error. ### How To Profile -Please take a look at [profile_modular_kernel.py](gh-file:tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py) +Please take a look at [profile_modular_kernel.py](../../tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py) The script can be used to generate Torch traces for a single `FusedMoEModularKernel::forward()` call for any compatible `FusedMoEPrepareAndFinalize` and `FusedMoEPermuteExpertsUnpermute` types. Example: `python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index e70ee4a076e5..fb64a7bb9c8f 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -6,14 +6,13 @@ When performing an inference with IO Processor plugins, the prompt type is defin ## Writing an IO Processor Plugin -IO Processor plugins implement the `IOProcessor` interface (): +IO Processor plugins implement the [`IOProcessor`][vllm.plugins.io_processors.interface.IOProcessor] interface: ```python -IOProcessorInput = TypeVar('IOProcessorInput') -IOProcessorOutput = TypeVar('IOProcessorOutput') +IOProcessorInput = TypeVar("IOProcessorInput") +IOProcessorOutput = TypeVar("IOProcessorOutput") class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): - def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config @@ -21,52 +20,66 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): def pre_process( self, prompt: IOProcessorInput, - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, - ) -> Union[PromptType, Sequence[PromptType]]: + ) -> PromptType | Sequence[PromptType]: raise NotImplementedError async def pre_process_async( self, prompt: IOProcessorInput, - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, - ) -> Union[PromptType, Sequence[PromptType]]: + ) -> PromptType | Sequence[PromptType]: return self.pre_process(prompt, request_id, **kwargs) @abstractmethod - def post_process(self, - model_output: Sequence[PoolingRequestOutput], - request_id: Optional[str] = None, - **kwargs) -> IOProcessorOutput: + def post_process( + self, + model_output: Sequence[PoolingRequestOutput], + request_id: str | None = None, + **kwargs, + ) -> IOProcessorOutput: raise NotImplementedError async def post_process_async( self, model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, ) -> IOProcessorOutput: - collected_output = [item async for i, item in model_output] + # We cannot guarantee outputs are returned in the same order they were + # fed to vLLM. 
+ # Let's sort them by id before post_processing + sorted_output = sorted( + [(i, item) async for i, item in model_output], key=lambda output: output[0] + ) + collected_output = [output[1] for output in sorted_output] return self.post_process(collected_output, request_id, **kwargs) @abstractmethod def parse_request(self, request: Any) -> IOProcessorInput: raise NotImplementedError + def validate_or_generate_params( + self, params: SamplingParams | PoolingParams | None = None + ) -> SamplingParams | PoolingParams: + return params or PoolingParams() + @abstractmethod def output_to_response( - self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + self, plugin_output: IOProcessorOutput + ) -> IOProcessorResponse: raise NotImplementedError ``` The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods. The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. +The `validate_or_generate_params` method is used by the plugin to validate any `SamplingParams`/`PoolingParams` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters. +The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py). -The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is available here . -An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online () and offline () inference examples. +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
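To make the interface concrete, here is a bare-bones plugin sketch. It is illustrative only: the import path for `IOProcessor` is assumed from the interface reference above, the class name and prompt fields are made up for the example, and a real plugin would wrap its result in an `IOProcessorResponse` inside `output_to_response`.

```python
from collections.abc import Sequence
from typing import Any

from vllm.plugins.io_processors.interface import IOProcessor  # assumed import path


class UppercasePromptProcessor(IOProcessor[dict, dict]):
    """Toy plugin: accepts {"text": ...}, upper-cases it before inference,
    and summarizes the pooled outputs afterwards."""

    def parse_request(self, request: Any) -> dict:
        if not isinstance(request, dict) or "text" not in request:
            raise ValueError("expected a dict with a 'text' field")
        return request

    def pre_process(self, prompt: dict, request_id: str | None = None, **kwargs):
        # Build a regular vLLM prompt from the validated plugin input.
        return {"prompt": prompt["text"].upper()}

    def post_process(
        self,
        model_output: Sequence,
        request_id: str | None = None,
        **kwargs,
    ) -> dict:
        # Collapse the pooled outputs into a small, plugin-specific payload.
        return {"num_outputs": len(model_output)}

    def output_to_response(self, plugin_output: dict):
        # A real plugin would build an IOProcessorResponse here; returning the
        # plain dict keeps this sketch self-contained.
        return plugin_output
```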
## Using an IO Processor plugin diff --git a/docs/design/logits_processors.md b/docs/design/logits_processors.md index 20d78ca3aae2..da61d2a85e46 100644 --- a/docs/design/logits_processors.md +++ b/docs/design/logits_processors.md @@ -174,7 +174,7 @@ The previous sections alluded to the interfaces which vLLM logits processors mus from collections.abc import Sequence from dataclasses import dataclass from enum import Enum, auto - from typing import TYPE_CHECKING, Optional + from typing import TYPE_CHECKING import torch @@ -244,7 +244,7 @@ The previous sections alluded to the interfaces which vLLM logits processors mus @abstractmethod def update_state( self, - batch_update: Optional["BatchUpdate"], + batch_update: "BatchUpdate" | None, ) -> None: """Called when there are new output tokens, prior to each forward pass. @@ -274,7 +274,7 @@ A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) * Return `True` if the logits processor is argmax invariant (never changes what is the highest-logit-value token ID for a given request), `False` if the logits processor may modify argmax * `is_argmax_invariant()` is evaluated once at startup; if `True`, vLLM will skip applying this logits processor in a given step when all requests use greedy sampling -* `update_state(self, batch_update: Optional["BatchUpdate"]) -> None`: +* `update_state(self, batch_update: "BatchUpdate" | None) -> None`: * Consume a `BatchUpdate` data structure representing persistent batch state changes at the beginning of the current engine step * Use the `BatchUpdate` members to update logits processor internal state * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added. diff --git a/docs/design/metrics.md b/docs/design/metrics.md index 90b2fd32f297..313c9aaebd26 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -1,12 +1,12 @@ # Metrics -Ensure the v1 LLM Engine exposes a superset of the metrics available in v0. +vLLM exposes a rich set of metrics to support observability and capacity planning for the V1 engine. ## Objectives -- Achieve parity of metrics between v0 and v1. -- The priority use case is accessing these metrics via Prometheus, as this is what we expect to be used in production environments. -- Logging support (i.e. printing metrics to the info log) is provided for more ad-hoc testing, debugging, development, and exploratory use cases. +- Provide comprehensive coverage of engine and request level metrics to aid production monitoring. +- Prioritize Prometheus integrations, as this is what we expect to be used in production environments. +- Offer logging support (i.e. printing metrics to the info log) for ad-hoc testing, debugging, development, and exploratory use cases. ## Background @@ -17,51 +17,42 @@ Metrics in vLLM can be categorized as follows: The mental model is that server-level metrics help explain the values of request-level metrics. 
-### v0 Metrics - -In v0, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix: - -- `vllm:num_requests_running` (Gauge) -- `vllm:num_requests_swapped` (Gauge) -- `vllm:num_requests_waiting` (Gauge) -- `vllm:gpu_cache_usage_perc` (Gauge) -- `vllm:cpu_cache_usage_perc` (Gauge) -- `vllm:gpu_prefix_cache_hit_rate` (Gauge) -- `vllm:cpu_prefix_cache_hit_rate` (Gauge) -- `vllm:prompt_tokens_total` (Counter) -- `vllm:generation_tokens_total` (Counter) -- `vllm:request_success_total` (Counter) -- `vllm:request_prompt_tokens` (Histogram) -- `vllm:request_generation_tokens` (Histogram) -- `vllm:time_to_first_token_seconds` (Histogram) -- `vllm:time_per_output_token_seconds` (Histogram) -- `vllm:e2e_request_latency_seconds` (Histogram) -- `vllm:request_queue_time_seconds` (Histogram) -- `vllm:request_inference_time_seconds` (Histogram) -- `vllm:request_prefill_time_seconds` (Histogram) -- `vllm:request_decode_time_seconds` (Histogram) -- `vllm:request_max_num_generation_tokens` (Histogram) -- `vllm:num_preemptions_total` (Counter) -- `vllm:cache_config_info` (Gauge) -- `vllm:lora_requests_info` (Gauge) -- `vllm:tokens_total` (Counter) -- `vllm:iteration_tokens_total` (Histogram) -- `vllm:time_in_queue_requests` (Histogram) -- `vllm:model_forward_time_milliseconds` (Histogram) -- `vllm:model_execute_time_milliseconds` (Histogram) -- `vllm:request_params_n` (Histogram) -- `vllm:request_params_max_tokens` (Histogram) -- `vllm:spec_decode_draft_acceptance_rate` (Gauge) -- `vllm:spec_decode_efficiency` (Gauge) -- `vllm:spec_decode_num_accepted_tokens_total` (Counter) -- `vllm:spec_decode_num_draft_tokens_total` (Counter) -- `vllm:spec_decode_num_emitted_tokens_total` (Counter) +### v1 Metrics + +In v1, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix: + +- `vllm:num_requests_running` (Gauge) - Number of requests currently running. +- `vllm:num_requests_waiting` (Gauge) - Number of requests currently waiting. +- `vllm:kv_cache_usage_perc` (Gauge) - Fraction of used KV cache blocks (0–1). +- `vllm:prefix_cache_queries` (Counter) - Number of prefix cache queries. +- `vllm:prefix_cache_hits` (Counter) - Number of prefix cache hits. +- `vllm:mm_cache_queries` (Counter) - (For multimodal models) Number of multimodal cache queries. +- `vllm:mm_cache_hits` (Counter) - (For multimodal models) Number of multimodal cache hits. +- `vllm:num_preemptions_total` (Counter) - Number of preemptions. +- `vllm:prompt_tokens_total` (Counter) - Total number of prompt tokens processed. +- `vllm:generation_tokens_total` (Counter) - Total number of generated tokens. +- `vllm:iteration_tokens_total` (Histogram) - Histogram of tokens processed in each engine step. +- `vllm:cache_config_info` (Gauge) - Information about the cache configuration. +- `vllm:request_success_total` (Counter) - Number of finished requests (by finish reason). +- `vllm:request_prompt_tokens` (Histogram) - Histogram of input prompt token counts. +- `vllm:request_generation_tokens` (Histogram) - Histogram of generation token counts. +- `vllm:request_params_n` (Histogram) - Histogram of request parameter n. +- `vllm:request_params_max_tokens` (Histogram) - Histogram of max_tokens parameter in requests. +- `vllm:time_to_first_token_seconds` (Histogram) - Time to first token (TTFT). +- `vllm:inter_token_latency_seconds` (Histogram) - Inter-token latency.
+- `vllm:e2e_request_latency_seconds` (Histogram) - End-to-end request latency. +- `vllm:request_queue_time_seconds` (Histogram) - Time spent in the queue. +- `vllm:request_inference_time_seconds` (Histogram) - Request inference time. +- `vllm:request_prefill_time_seconds` (Histogram) - Request prefill time. +- `vllm:request_decode_time_seconds` (Histogram) - Request decode time. These are documented under [Inferencing and Serving -> Production Metrics](../usage/metrics.md). ### Grafana Dashboard -vLLM also provides [a reference example](../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard. +vLLM also provides [a reference example](../../examples/online_serving/prometheus_grafana/README.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard. The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important: @@ -80,13 +71,13 @@ The subset of metrics exposed in the Grafana dashboard gives us an indication of - `vllm:request_decode_time_seconds` - Requests decode time. - `vllm:request_max_num_generation_tokens` - Max generation tokens in a sequence group. -See [the PR which added this Dashboard](gh-pr:2316) for interesting and useful background on the choices made here. +See [the PR which added this Dashboard](https://github.com/vllm-project/vllm/pull/2316) for interesting and useful background on the choices made here. ### Prometheus Client Library -Prometheus support was initially added [using the aioprometheus library](gh-pr:1890), but a switch was made quickly to [prometheus_client](gh-pr:2730). The rationale is discussed in both linked PRs. +Prometheus support was initially added [using the aioprometheus library](https://github.com/vllm-project/vllm/pull/1890), but a switch was made quickly to [prometheus_client](https://github.com/vllm-project/vllm/pull/2730). The rationale is discussed in both linked PRs. -With the switch to `aioprometheus`, we lost a `MetricsMiddleware` to track HTTP metrics, but this was reinstated [using prometheus_fastapi_instrumentator](gh-pr:15657): +During those migrations we briefly lost a `MetricsMiddleware` to track HTTP metrics, but this was reinstated [using prometheus_fastapi_instrumentator](https://github.com/vllm-project/vllm/pull/15657): ```bash $ curl http://0.0.0.0:8000/metrics 2>/dev/null | grep -P '^http_(?!.*(_bucket|_created|_sum)).*' @@ -99,7 +90,9 @@ http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201 ### Multi-process Mode -In v0, metrics are collected in the engine core process and we use multiprocess mode to make them available in the API server process. See . +Historically, metrics were collected in the engine core process and multiprocess mode was used to make them available in the API server process. See . + +More recently, metrics are collected in the API server process and multiprocess mode is only used when `--api-server-count > 1`. See and details on [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing). ### Built in Python/Process Metrics @@ -116,41 +109,37 @@ The following metrics are supported by default by `prometheus_client`, but they - `process_open_fds` - `process_max_fds` -This is relevant because if we move away from multiprocess mode in v1, -we get these back. 
However, it's questionable how relevant these are -if they don't aggregate these stats for all processes that make up a -vLLM instance. +Therefore, these metrics are unavailable when `--api-server-count > 1`. It's questionable how relevant these are since they do not aggregate these stats for all processes that make up a vLLM instance. -### v0 PRs and Issues +## Metrics Design -For background, these are some of the relevant PRs which added the v0 metrics: +The ["Even Better Observability"](https://github.com/vllm-project/vllm/issues/3616) feature was where much of the metrics design was planned. For example, see where [a detailed roadmap was laid out](https://github.com/vllm-project/vllm/issues/3616#issuecomment-2030858781). -- -- -- -- -- +### Legacy PRs -Also note the ["Even Better Observability"](gh-issue:3616) feature where e.g. [a detailed roadmap was laid out](gh-issue:3616#issuecomment-2030858781). +To help understand the background to the metrics design, here are some of the relevant PRs which added the original, now legacy, metrics: -## v1 Design +- +- +- +- +- -### v1 PRs +### Metrics Implementation PRs -For background, here are the relevant v1 PRs relating to the v1 -metrics issue : +For background, here are the relevant PRs relating to the metrics implementation : -- -- -- -- -- -- -- -- -- -- -- +- +- +- +- +- +- +- +- +- +- +- ### Metrics Collection @@ -394,15 +383,14 @@ distinguish between per-adapter counts. This should be revisited. Note that `multiprocess_mode="livemostrecent"` is used - the most recent metric is used, but only from currently running processes. -This was added in and there is +This was added in and there is [at least one known user](https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/54). -If we revisit this design and deprecate the old metric, we should reduce -the need for a significant deprecation period by making the change in -v0 also and asking this project to move to the new metric. +If we revisit this design and deprecate the old metric, we should +coordinate with downstream users so they can migrate before the removal. ### Prefix Cache metrics -The discussion in about adding prefix cache metrics yielded +The discussion in about adding prefix cache metrics yielded some interesting points which may be relevant to how we approach future metrics. @@ -439,8 +427,8 @@ suddenly (from their perspective) when it is removed, even if there is an equivalent metric for them to use. As an example, see how `vllm:avg_prompt_throughput_toks_per_s` was -[deprecated](gh-pr:2764) (with a comment in the code), -[removed](gh-pr:12383), and then [noticed by a user](gh-issue:13218). +[deprecated](https://github.com/vllm-project/vllm/pull/2764) (with a comment in the code), +[removed](https://github.com/vllm-project/vllm/pull/12383), and then [noticed by a user](https://github.com/vllm-project/vllm/issues/13218). In general: @@ -460,20 +448,20 @@ the project-wide deprecation policy. ### Unimplemented - `vllm:tokens_total` -Added by , but apparently never implemented. This can just be +Added by , but apparently never implemented. This can just be removed.
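Tying the counter-based approach back to the prefix cache discussion above, the following sketch (again assuming a local server on port 8000 and `prometheus_client` on the client side) derives a hit rate from the `vllm:prefix_cache_queries` and `vllm:prefix_cache_hits` counters. The exact exposed sample names may carry a `_total` suffix, so both forms are checked.

```python
# Illustrative derivation of a prefix cache hit rate from the two counters.
from urllib.request import urlopen

from prometheus_client.parser import text_string_to_metric_families


def counter_value(metrics_text: str, name: str) -> float:
    # Accept either the bare name or the conventional Prometheus `_total` suffix.
    wanted = {name, name + "_total"}
    total = 0.0
    for family in text_string_to_metric_families(metrics_text):
        for sample in family.samples:
            if sample.name in wanted:
                total += sample.value
    return total


text = urlopen("http://localhost:8000/metrics").read().decode()
queries = counter_value(text, "vllm:prefix_cache_queries")
hits = counter_value(text, "vllm:prefix_cache_hits")
print(f"prefix cache hit rate: {hits / queries:.2%}" if queries else "no queries yet")
```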
### Duplicated - Queue Time The `vllm:time_in_queue_requests` Histogram metric was added by - and its calculation is: + and its calculation is: ```python self.metrics.first_scheduled_time = now self.metrics.time_in_queue = now - self.metrics.arrival_time ``` -Two weeks later, added `vllm:request_queue_time_seconds` leaving +Two weeks later, added `vllm:request_queue_time_seconds` leaving us with: ```python @@ -491,7 +479,7 @@ if seq_group.is_finished(): This seems duplicative, and one of them should be removed. The latter is used by the Grafana dashboard, so we should deprecate or remove the -former from v0. +former. ### Prefix Cache Hit Rate @@ -500,7 +488,7 @@ See above - we now expose 'queries' and 'hits' counters rather than a ### KV Cache Offloading -Two v0 metrics relate to a "swapped" preemption mode that is no +Two legacy metrics relate to a "swapped" preemption mode that is no longer relevant in v1: - `vllm:num_requests_swapped` @@ -511,7 +499,7 @@ cache to complete other requests), we swap kv cache blocks out to CPU memory. This is also known as "KV cache offloading" and is configured with `--swap-space` and `--preemption-mode`. -In v0, [vLLM has long supported beam search](gh-issue:6226). The +Historically, [vLLM has long supported beam search](https://github.com/vllm-project/vllm/issues/6226). The SequenceGroup encapsulated the idea of N Sequences which all shared the same prompt kv blocks. This enabled KV cache block sharing between requests, and copy-on-write to do branching. CPU @@ -524,7 +512,7 @@ and the part of the prompt that was evicted can be recomputed. SequenceGroup was removed in V1, although a replacement will be required for "parallel sampling" (`n>1`). -[Beam search was moved out of the core (in V0)](gh-issue:8306). There was a +[Beam search was moved out of the core](https://github.com/vllm-project/vllm/issues/8306). There was a lot of complex code for a very uncommon feature. In V1, with prefix caching being better (zero over head) and therefore @@ -535,11 +523,11 @@ better. ### Parallel Sampling -Some v0 metrics are only relevant in the context of "parallel +Some legacy metrics are only relevant in the context of "parallel sampling". This is where the `n` parameter in a request is used to request multiple completions from the same prompt. -As part of adding parallel sampling support in , we should +As part of adding parallel sampling support in , we should also add these metrics. - `vllm:request_params_n` (Histogram) @@ -554,7 +542,7 @@ also add these metrics. ### Speculative Decoding -Some v0 metrics are specific to "speculative decoding". This is where +Some legacy metrics are specific to "speculative decoding". This is where we generate candidate tokens using a faster, approximate method or model and then validate those tokens with the larger model. @@ -564,9 +552,9 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_num_draft_tokens_total` (Counter) - `vllm:spec_decode_num_emitted_tokens_total` (Counter) -There is a PR under review () to add "prompt lookup (ngram)" +There is a PR under review () to add "prompt lookup (ngram)" speculative decoding to v1. Other techniques will follow. We should -revisit the v0 metrics in this context. +revisit these metrics in this context. !!! 
note We should probably expose acceptance rate as separate accepted @@ -585,7 +573,7 @@ see: - [Standardizing Large Model Server Metrics in Kubernetes](https://docs.google.com/document/d/1SpSp1E6moa4HSrJnS4x3NpLuj88sMXr2tbofKlzTZpk) - [Benchmarking LLM Workloads for Performance Evaluation and Autoscaling in Kubernetes](https://docs.google.com/document/d/1k4Q4X14hW4vftElIuYGDu5KDe2LtV1XammoG-Xi3bbQ) - [Inference Perf](https://github.com/kubernetes-sigs/wg-serving/tree/main/proposals/013-inference-perf) -- and . +- and . This is a non-trivial topic. Consider this comment from Rob: @@ -639,7 +627,7 @@ metrics are often relatively straightforward to add: metrics are usually of very limited use unless they can be enabled by default and in production. 3. They have an impact on development and maintenance of the - project. Every metric added to v0 has made this v1 effort more + project. Every metric added over time has made this effort more time-consuming, and perhaps not all metrics justify this ongoing investment in their maintenance. @@ -650,24 +638,24 @@ performance and health. Tracing, on the other hand, tracks individual requests as they move through different services and components. Both fall under the more general heading of "Observability". -v0 has support for OpenTelemetry tracing: +vLLM has support for OpenTelemetry tracing: -- Added by +- Added by and reinstated by - Configured with `--oltp-traces-endpoint` and `--collect-detailed-traces` - [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/) - [User-facing docs](../examples/online_serving/opentelemetry.md) - [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f) - [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview) - + OpenTelemetry has a [Gen AI Working Group](https://github.com/open-telemetry/community/blob/main/projects/gen-ai.md). -Since metrics is a big enough topic on its own, we are going to tackle -the topic of tracing in v1 separately. +Since metrics is a big enough topic on its own, we consider the topic +of tracing to be quite separate from metrics. ### OpenTelemetry Model Forward vs Execute Time -In v0, we have the following two metrics: +The current implementation exposes the following two metrics: - `vllm:model_forward_time_milliseconds` (Histogram) - The time spent in the model forward pass when this request was in the batch. @@ -683,7 +671,7 @@ documentation for this option states: > use of possibly costly and or blocking operations and hence might > have a performance impact. -The metrics were added by and who up in an OpenTelemetry trace +The metrics were added by and who up in an OpenTelemetry trace as: ```text diff --git a/docs/design/mm_processing.md b/docs/design/mm_processing.md index 1e9b6ad6e821..ee56ac5b98ef 100644 --- a/docs/design/mm_processing.md +++ b/docs/design/mm_processing.md @@ -1,6 +1,6 @@ # Multi-Modal Data Processing -To enable various optimizations in vLLM such as [chunked prefill][chunked-prefill] and [prefix caching](../features/automatic_prefix_caching.md), we use [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] to provide the correspondence between placeholder feature tokens (e.g. ``) and multi-modal inputs (e.g. the raw input image) based on the outputs of HF processor. 
+To enable various optimizations in vLLM such as [chunked prefill](../configuration/optimization.md#chunked-prefill) and [prefix caching](../features/automatic_prefix_caching.md), we use [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] to provide the correspondence between placeholder feature tokens (e.g. ``) and multi-modal inputs (e.g. the raw input image) based on the outputs of HF processor. Here are the main features of [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor]: @@ -41,14 +41,10 @@ While HF processors support text + multi-modal inputs natively, this is not so f Moreover, since the tokenized text has not passed through the HF processor, we have to apply Step 3 by ourselves to keep the output tokens and multi-modal data consistent with each other. -[](){ #mm-dummy-text } - ### Dummy text We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via [get_dummy_text][vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_text]. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. -[](){ #mm-automatic-prompt-updating } - ### Automatic prompt updating We address the second issue by implementing model-agnostic code in @@ -60,8 +56,8 @@ With the help of dummy text and automatic prompt updating, our multi-modal proce ## Processor Output Caching -Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238). To alleviate this problem, we cache the multi-modal outputs of HF processor to avoid processing the same multi-modal input (e.g. image) again. +Some HF processors, such as the one for Qwen2-VL, are [very slow](https://github.com/vllm-project/vllm/issues/9238). To alleviate this problem, we cache the multi-modal outputs of HF processor to avoid processing the same multi-modal input (e.g. image) again. When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache. -Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text][mm-dummy-text] to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating][mm-automatic-prompt-updating] afterwards to keep the output tokens and multi-modal data consistent with each other. +Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other. 
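The lookup/batch/merge pattern described above can be sketched in a few lines. This is a conceptual illustration only, not vLLM's implementation: item hashing, cache eviction, and the actual HF processor call are all abstracted away behind placeholder types.

```python
# Conceptual sketch of the "only process the missing items" caching pattern.
from collections.abc import Callable, Hashable
from typing import Any


def process_with_cache(
    items: list[Hashable],
    cache: dict[Hashable, Any],
    process_batch: Callable[[list[Hashable]], list[Any]],
) -> list[Any]:
    # Find the items whose processed outputs are not cached yet.
    missing = [item for item in items if item not in cache]
    if missing:
        # Process all cache misses in a single batch, then cache the results.
        for item, output in zip(missing, process_batch(missing)):
            cache[item] = output
    # Merge cached and freshly processed outputs, preserving input order.
    return [cache[item] for item in items]
```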
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 0831c5bc790d..633e23eea33e 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -92,8 +92,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels | flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | deep gemm+triton2 | standard,
batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],
[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] | -| marlin | standard | 3 | 3 | silu,
swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] | -| marlin experts | standard | N/A | N/A | silu,
swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] | +| marlin | standard | 3 | 3 | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | +| marlin experts | standard,
batched | N/A | N/A | silu,
swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | @@ -115,6 +115,6 @@ The following table shows "families" of modular kernels that are intended to wor | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | |----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------| -| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | -| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`| -| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | +| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | +| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`,
`BatchedMarlinExperts`| +| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 6e92b20d267b..d6bd92278829 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -2,7 +2,7 @@ ## Debugging -Please see the [Troubleshooting][troubleshooting-python-multiprocessing] +Please see the [Troubleshooting](../usage/troubleshooting.md#python-multiprocessing) page for information on known issues and how to solve them. ## Introduction @@ -82,7 +82,7 @@ There are other miscellaneous places hard-coding the use of `spawn`: Related PRs: -- +- ## Prior State in v1 diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md index a384c6289f4f..dc2f7c4aed3c 100644 --- a/docs/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -41,7 +41,7 @@ Every plugin has three parts: 1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group `vllm.general_plugins` to register general plugins. This is the key of `entry_points` in the `setup.py` file. Always use `vllm.general_plugins` for vLLM's general plugins. 2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name. -3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module. +3. **Plugin value**: The fully qualified name of the function or module to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module. ## Types of supported plugins @@ -51,6 +51,8 @@ Every plugin has three parts: - **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor's class fully qualified name. +- **Stat logger plugins** (with group name `vllm.stat_logger_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree loggers into vLLM. The entry point should be a class that subclasses StatLoggerBase. + ## Guidelines for Writing Plugins - **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes. diff --git a/docs/design/prefix_caching.md b/docs/design/prefix_caching.md index 9941837bf165..bd4070f381d8 100644 --- a/docs/design/prefix_caching.md +++ b/docs/design/prefix_caching.md @@ -112,8 +112,8 @@ class KVCacheBlock: ref_cnt: int # The pointers to form a doubly linked list for the free queue. 
- prev_free_block: Optional["KVCacheBlock"] = None - next_free_block: Optional["KVCacheBlock"] = None + prev_free_block: "KVCacheBlock | None" = None + next_free_block: "KVCacheBlock | None" = None ``` There are two design points to highlight: @@ -213,22 +213,22 @@ In this example, we assume the block size is 4 (each block can cache 4 tokens), ![Example Time 1](../assets/design/prefix_caching/example-time-1.png) -**Time 3: Request 0 makes the block 3 full and asks for a new block to keep decoding.** We cache block 3 and allocate block 4. +**Time 2: Request 0 makes the block 3 full and asks for a new block to keep decoding.** We cache block 3 and allocate block 4. -![Example Time 3](../assets/design/prefix_caching/example-time-3.png) +![Example Time 2](../assets/design/prefix_caching/example-time-3.png) -**Time 4: Request 1 comes in with the 14 prompt tokens, where the first 10 tokens are the same as request 0.** We can see that only the first 2 blocks (8 tokens) hit the cache, because the 3rd block only matches 2 of 4 tokens. +**Time 3: Request 1 comes in with the 14 prompt tokens, where the first 10 tokens are the same as request 0.** We can see that only the first 2 blocks (8 tokens) hit the cache, because the 3rd block only matches 2 of 4 tokens. -![Example Time 4](../assets/design/prefix_caching/example-time-4.png) +![Example Time 3](../assets/design/prefix_caching/example-time-4.png) -**Time 5: Request 0 is finished and free.** Blocks 2, 3 and 4 are added to the free queue in the reverse order (but block 2 and 3 are still cached). Block 0 and 1 are not added to the free queue because they are being used by Request 1. +**Time 4: Request 0 is finished and free.** Blocks 2, 3 and 4 are added to the free queue in the reverse order (but block 2 and 3 are still cached). Block 0 and 1 are not added to the free queue because they are being used by Request 1. -![Example Time 5](../assets/design/prefix_caching/example-time-5.png) +![Example Time 4](../assets/design/prefix_caching/example-time-5.png) -**Time 6: Request 1 is finished and free.** +**Time 5: Request 1 is finished and free.** -![Example Time 6](../assets/design/prefix_caching/example-time-6.png) +![Example Time 5](../assets/design/prefix_caching/example-time-6.png) -**Time 7: Request 2 comes in with the 29 prompt tokens, where the first 12 tokens are the same as request 0\.** Note that even the block order in the free queue was `7 - 8 - 9 - 4 - 3 - 2 - 6 - 5 - 1 - 0`, the cache hit blocks (i.e., 0, 1, 2) are touched and removed from the queue before allocation, so the free queue becomes `7 - 8 - 9 - 4 - 3 - 6 - 5`. As a result, the allocated blocks are 0 (cached), 1 (cached), 2 (cached), 7, 8, 9, 4, 3 (evicted). +**Time 6: Request 2 comes in with the 29 prompt tokens, where the first 12 tokens are the same as request 0\.** Note that even the block order in the free queue was `7 - 8 - 9 - 4 - 3 - 2 - 6 - 5 - 1 - 0`, the cache hit blocks (i.e., 0, 1, 2) are touched and removed from the queue before allocation, so the free queue becomes `7 - 8 - 9 - 4 - 3 - 6 - 5`. As a result, the allocated blocks are 0 (cached), 1 (cached), 2 (cached), 7, 8, 9, 4, 3 (evicted). 
-![Example Time 7](../assets/design/prefix_caching/example-time-7.png) +![Example Time 6](../assets/design/prefix_caching/example-time-7.png) diff --git a/docs/design/torch_compile.md b/docs/design/torch_compile.md index 32a4efef71fb..5a3ca2de8219 100644 --- a/docs/design/torch_compile.md +++ b/docs/design/torch_compile.md @@ -19,8 +19,8 @@ vLLM will take all the available factors into consideration, and decide a direct The factors considered include: -- All the related configs (see the `compute_hash` functions in their respective configs in the [config folder](gh-file:vllm/config)) -- PyTorch configs (see the `compute_hash` functions in the [compiler_interface.py](gh-file:vllm/compilation/compiler_interface.py)) +- All the related configs (see the `compute_hash` functions in their respective configs in the [config folder](../../vllm/config)) +- PyTorch configs (see the `compute_hash` functions in the [compiler_interface.py](../../vllm/compilation/compiler_interface.py)) - The model's forward function and the relevant functions called by the forward function (see below) With all these factors taken into consideration, usually we can guarantee that the cache is safe to use, and will not cause any unexpected behavior. Therefore, the cache is enabled by default. If you want to debug the compilation process, or if you suspect the cache is causing some issues, you can disable it by setting the environment variable `VLLM_DISABLE_COMPILE_CACHE=1`. diff --git a/docs/features/README.md b/docs/features/README.md index 05ce0b57a9fc..ad9de9ff8f36 100644 --- a/docs/features/README.md +++ b/docs/features/README.md @@ -36,45 +36,43 @@ th:not(:first-child) { } -| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | [pooling](../models/pooling_models.md) | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | [prompt-embeds](prompt_embeds.md) | +| Feature | [CP](../configuration/optimization.md#chunked-prefill) | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | [pooling](../models/pooling_models.md) | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | [prompt-embeds](prompt_embeds.md) | |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| -| [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | | +| [CP](../configuration/optimization.md#chunked-prefill) | ✅ | | | | | | | | | | | | | | | | [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | | | [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | | | [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | | | | | | | | | | | | | CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | | [pooling](../models/pooling_models.md) | 🟠\* | 🟠\* | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | -| enc-dec | ❌ | [❌](gh-issue:7366) | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | | +| enc-dec | ❌ | [❌](https://github.com/vllm-project/vllm/issues/7366) | ❌ | [❌](https://github.com/vllm-project/vllm/issues/7366) | ✅ | ✅ | ✅ | | | | | | | | | | logP | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | | | | async output | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | | | multi-step | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | | -| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](gh-pr:4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | | -| best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | | -| 
beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | | -| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](gh-issue:25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | +| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/pull/4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | | +| best-of | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ✅ | ✅ | | | +| beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | | +| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | \* Chunked prefill and prefix caching are only applicable to last-token pooling. ^ LoRA is only applicable to the language backbone of multimodal models. -[](){ #feature-x-hardware } - ### Feature x Hardware -| Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | TPU | -|-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------|-----| -| [CP][chunked-prefill] | [❌](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [APC](automatic_prefix_caching.md) | [❌](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | -| [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| enc-dec | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | -| [mm](multimodal_inputs.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| async output | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | -| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8477) | ✅ | ❌ | -| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | -| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? 
| [❌](gh-issue:25097) | +| Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | TPU | Intel GPU | +|-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------|-----| ------------| +| [CP](../configuration/optimization.md#chunked-prefill) | [❌](https://github.com/vllm-project/vllm/issues/2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [APC](automatic_prefix_caching.md) | [❌](https://github.com/vllm-project/vllm/issues/3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [🟠](https://github.com/vllm-project/vllm/issues/26963) | +| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [❌](https://github.com/vllm-project/vllm/issues/26970) | +| [pooling](../models/pooling_models.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | +| enc-dec | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | +| [mm](multimodal_inputs.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | [🟠](https://github.com/vllm-project/vllm/issues/26965) | +| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | +| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | +| async output | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | +| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ❌ | ✅ | +| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | +| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | +| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ | diff --git a/docs/features/automatic_prefix_caching.md b/docs/features/automatic_prefix_caching.md index c529da684e36..3718a4b74eb2 100644 --- a/docs/features/automatic_prefix_caching.md +++ b/docs/features/automatic_prefix_caching.md @@ -11,7 +11,7 @@ Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, Set `enable_prefix_caching=True` in vLLM engine to enable APC. Here is an example: - +[examples/offline_inference/automatic_prefix_caching.py](../../examples/offline_inference/automatic_prefix_caching.py) ## Example workloads diff --git a/docs/features/custom_logitsprocs.md b/docs/features/custom_logitsprocs.md index 201b340c5972..b8ad53863cd7 100644 --- a/docs/features/custom_logitsprocs.md +++ b/docs/features/custom_logitsprocs.md @@ -93,7 +93,6 @@ The contrived example below implements a custom logits processor which consumes ??? code "Example custom logits processor definition" ``` python - from typing import Optional import torch from vllm.config import VllmConfig from vllm.sampling_params import SamplingParams @@ -112,7 +111,7 @@ The contrived example below implements a custom logits processor which consumes """Never impacts greedy sampling""" return False - def update_state(self, batch_update: Optional[BatchUpdate]): + def update_state(self, batch_update: BatchUpdate | None): if not batch_update: return diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md index fe065b52268a..3e8cb87e37d3 100644 --- a/docs/features/disagg_prefill.md +++ b/docs/features/disagg_prefill.md @@ -17,14 +17,14 @@ Two main reasons: ## Usage example -Please refer to for the example usage of disaggregated prefilling. +Please refer to [examples/online_serving/disaggregated_prefill.sh](../../examples/online_serving/disaggregated_prefill.sh) for the example usage of disaggregated prefilling. 
Now supports 5 types of connectors: -- **SharedStorageConnector**: refer to for the example usage of SharedStorageConnector disaggregated prefilling. -- **LMCacheConnectorV1**: refer to for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission. -- **NixlConnector**: refer to for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md). -- **P2pNcclConnector**: refer to for the example usage of P2pNcclConnector disaggregated prefilling. +- **SharedStorageConnector**: refer to [examples/offline_inference/disaggregated-prefill-v1/run.sh](../../examples/offline_inference/disaggregated-prefill-v1/run.sh) for the example usage of SharedStorageConnector disaggregated prefilling. +- **LMCacheConnectorV1**: refer to [examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh](../../examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh) for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission. +- **NixlConnector**: refer to [tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh](../../tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh) for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv. For detailed usage guide, see [NixlConnector Usage Guide](nixl_connector_usage.md). +- **P2pNcclConnector**: refer to [examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh](../../examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh) for the example usage of P2pNcclConnector disaggregated prefilling. - **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as: ```bash @@ -45,7 +45,7 @@ For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as: ## Benchmarks -Please refer to for disaggregated prefilling benchmarks. +Please refer to [benchmarks/disagg_benchmarks](../../benchmarks/disagg_benchmarks) for disaggregated prefilling benchmarks. ## Development diff --git a/docs/features/lora.md b/docs/features/lora.md index db794b2ebd71..3a85b52d89b6 100644 --- a/docs/features/lora.md +++ b/docs/features/lora.md @@ -32,7 +32,7 @@ the third parameter is the path to the LoRA adapter. sampling_params = SamplingParams( temperature=0, max_tokens=256, - stop=["[/assistant]"] + stop=["[/assistant]"], ) prompts = [ @@ -43,11 +43,11 @@ the third parameter is the path to the LoRA adapter. outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest("sql_adapter", 1, sql_lora_path) + lora_request=LoRARequest("sql_adapter", 1, sql_lora_path), ) ``` -Check out for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options. +Check out [examples/offline_inference/multilora_inference.py](../../examples/offline_inference/multilora_inference.py) for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options. 
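For reference, the pieces shown in the hunks above can be assembled into a compact offline sketch. The base model and adapter repository below are an illustrative pair, and the prompt is shortened for brevity; treat the specific model choices as assumptions rather than requirements.

```python
# Condensed, illustrative offline LoRA example; model/adapter choices are assumptions.
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
sampling_params = SamplingParams(temperature=0, max_tokens=256, stop=["[/assistant]"])

prompts = [
    "[user] Write a SQL query to list all singers, given the table schema "
    "singer(singer_id, name, country, age) [/user] [assistant]",
]

outputs = llm.generate(
    prompts,
    sampling_params,
    # Adapter name, a globally unique integer id, and the local adapter path.
    lora_request=LoRARequest("sql_adapter", 1, sql_lora_path),
)
print(outputs[0].outputs[0].text)
```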
## Serving LoRA Adapters @@ -197,7 +197,7 @@ Alternatively, follow these example steps to implement your own plugin: lora_request = LoRARequest( lora_name=lora_name, lora_path=local_path, - lora_int_id=abs(hash(lora_name)) + lora_int_id=abs(hash(lora_name)), ) return lora_request ``` @@ -296,10 +296,7 @@ To this end, we allow registration of default multimodal LoRAs to handle this au if has_audio: question = f"<|audio|>{question}" chat = [ - { - "role": "user", - "content": question - } + {"role": "user", "content": question}, ] return tokenizer.apply_chat_template(chat, tokenize=False) diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index dcc5ea3b9096..caf458c24497 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -1,9 +1,9 @@ # Multimodal Inputs -This page teaches you how to pass multi-modal inputs to [multi-modal models][supported-mm-models] in vLLM. +This page teaches you how to pass multi-modal inputs to [multi-modal models](../models/supported_models.md#list-of-multimodal-language-models) in vLLM. !!! note - We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes, + We are actively iterating on multi-modal support. See [this RFC](https://github.com/vllm-project/vllm/issues/4194) for upcoming changes, and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests. !!! tip @@ -129,7 +129,7 @@ You can pass a single image to the `'image'` field of the multi-modal dictionary print(generated_text) ``` -Full example: +Full example: [examples/offline_inference/vision_language.py](../../examples/offline_inference/vision_language.py) To substitute multiple images inside the same text prompt, you can pass in a list of images instead: @@ -154,9 +154,7 @@ To substitute multiple images inside the same text prompt, you can pass in a lis outputs = llm.generate({ "prompt": prompt, - "multi_modal_data": { - "image": [image1, image2] - }, + "multi_modal_data": {"image": [image1, image2]}, }) for o in outputs: @@ -164,7 +162,7 @@ To substitute multiple images inside the same text prompt, you can pass in a lis print(generated_text) ``` -Full example: +Full example: [examples/offline_inference/vision_language_multi_image.py](../../examples/offline_inference/vision_language_multi_image.py) If using the [LLM.chat](../models/generative_models.md#llmchat) method, you can pass images directly in the message content using various formats: image URLs, PIL Image objects, or pre-computed embeddings: @@ -183,21 +181,24 @@ conversation = [ {"role": "assistant", "content": "Hello! How can I assist you today?"}, { "role": "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - },{ - "type": "image_pil", - "image_pil": image_pil - }, { - "type": "image_embeds", - "image_embeds": image_embeds - }, { - "type": "text", - "text": "What's in these images?" - }], + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + { + "type": "image_pil", + "image_pil": image_pil, + }, + { + "type": "image_embeds", + "image_embeds": image_embeds, + }, + { + "type": "text", + "text": "What's in these images?", + }, + ], }, ] @@ -224,7 +225,10 @@ Multi-image input can be extended to perform video captioning. We show this with message = { "role": "user", "content": [ - {"type": "text", "text": "Describe this set of frames. 
Consider the frames to be a part of the same video."}, + { + "type": "text", + "text": "Describe this set of frames. Consider the frames to be a part of the same video.", + }, ], } for i in range(len(video_frames)): @@ -255,13 +259,13 @@ When loading RGBA images (images with transparency), vLLM converts them to RGB f # Custom black background for dark theme llm = LLM( model="llava-hf/llava-1.5-7b-hf", - media_io_kwargs={"image": {"rgba_background_color": [0, 0, 0]}} + media_io_kwargs={"image": {"rgba_background_color": [0, 0, 0]}}, ) # Custom brand color background (e.g., blue) llm = LLM( model="llava-hf/llava-1.5-7b-hf", - media_io_kwargs={"image": {"rgba_background_color": [0, 0, 255]}} + media_io_kwargs={"image": {"rgba_background_color": [0, 0, 255]}}, ) ``` @@ -294,20 +298,23 @@ Instead of NumPy arrays, you can also pass `'torch.Tensor'` instances, as shown limit_mm_per_prompt={"video": 1}, ) - sampling_params = SamplingParams( - max_tokens=1024, - ) + sampling_params = SamplingParams(max_tokens=1024) video_messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": [ + { + "role": "system", + "content": "You are a helpful assistant.", + }, + { + "role": "user", + "content": [ {"type": "text", "text": "describe this video."}, { "type": "video", "video": video_path, "total_pixels": 20480 * 28 * 28, - "min_pixels": 16 * 28 * 28 - } + "min_pixels": 16 * 28 * 28, + }, ] }, ] @@ -339,26 +346,32 @@ Instead of NumPy arrays, you can also pass `'torch.Tensor'` instances, as shown !!! note 'process_vision_info' is only applicable to Qwen2.5-VL and similar models. -Full example: +Full example: [examples/offline_inference/vision_language.py](../../examples/offline_inference/vision_language.py) ### Audio Inputs You can pass a tuple `(array, sampling_rate)` to the `'audio'` field of the multi-modal dictionary. -Full example: +Full example: [examples/offline_inference/audio_language.py](../../examples/offline_inference/audio_language.py) ### Embedding Inputs To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. +You must enable this feature via `enable_mm_embeds=True`. + +!!! warning + The vLLM engine may crash if incorrect shape of embeddings is passed. + Only enable this flag for trusted users! + ??? code ```python from vllm import LLM # Inference with image embeddings as input - llm = LLM(model="llava-hf/llava-1.5-7b-hf") + llm = LLM(model="llava-hf/llava-1.5-7b-hf", enable_mm_embeds=True) # Refer to the HuggingFace repo for the correct format to use prompt = "USER: \nWhat is the content of this image?\nASSISTANT:" @@ -390,7 +403,11 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd image_embeds = torch.load(...) 
# Qwen2-VL - llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4}) + llm = LLM( + "Qwen/Qwen2-VL-2B-Instruct", + limit_mm_per_prompt={"image": 4}, + enable_mm_embeds=True, + ) mm_data = { "image": { "image_embeds": image_embeds, @@ -400,7 +417,12 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd } # MiniCPM-V - llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={"image": 4}) + llm = LLM( + "openbmb/MiniCPM-V-2_6", + trust_remote_code=True, + limit_mm_per_prompt={"image": 4}, + enable_mm_embeds=True, + ) mm_data = { "image": { "image_embeds": image_embeds, @@ -427,11 +449,11 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions A chat template is **required** to use Chat Completions API. For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`. - If no default chat template is available, we will first look for a built-in fallback in . + If no default chat template is available, we will first look for a built-in fallback in [vllm/transformers_utils/chat_templates/registry.py](../../vllm/transformers_utils/chat_templates/registry.py). If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument. - For certain models, we provide alternative chat templates inside . - For example, VLM2Vec uses which is different from the default one for Phi-3-Vision. + For certain models, we provide alternative chat templates inside [examples](../../examples). + For example, VLM2Vec uses [examples/template_vlm2vec_phi3v.jinja](../../examples/template_vlm2vec_phi3v.jinja) which is different from the default one for Phi-3-Vision. ### Image Inputs @@ -465,21 +487,24 @@ Then, you can use the OpenAI client as follows: chat_response = client.chat.completions.create( model="microsoft/Phi-3.5-vision-instruct", - messages=[{ - "role": "user", - "content": [ - # NOTE: The prompt formatting with the image token `` is not needed - # since the prompt will be processed automatically by the API server. - {"type": "text", "text": "What’s in this image?"}, - { - "type": "image_url", - "image_url": { - url": image_url + messages=[ + { + "role": "user", + "content": [ + # NOTE: The prompt formatting with the image token `` is not needed + # since the prompt will be processed automatically by the API server. 
+ { + "type": "text", + "text": "What’s in this image?", }, - "uuid": image_url # Optional - }, - ], - }], + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_url, # Optional + }, + ], + } + ], ) print("Chat completion output:", chat_response.choices[0].message.content) @@ -489,31 +514,32 @@ Then, you can use the OpenAI client as follows: chat_response = client.chat.completions.create( model="microsoft/Phi-3.5-vision-instruct", - messages=[{ - "role": "user", - "content": [ - {"type": "text", "text": "What are the animals in these images?"}, - { - "type": "image_url", - "image_url": { - "url": image_url_duck + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the animals in these images?", }, - "uuid": image_url_duck # Optional - }, - { - "type": "image_url", - "image_url": { - "url": image_url_lion + { + "type": "image_url", + "image_url": {"url": image_url_duck}, + "uuid": image_url_duck, # Optional }, - "uuid": image_url_lion # Optional - }, - ], - }], + { + "type": "image_url", + "image_url": {"url": image_url_lion}, + "uuid": image_url_lion, # Optional + }, + ], + } + ], ) print("Chat completion output:", chat_response.choices[0].message.content) ``` -Full example: +Full example: [examples/online_serving/openai_chat_completion_client_for_multimodal.py](../../examples/online_serving/openai_chat_completion_client_for_multimodal.py) !!! tip Loading from local file paths is also supported on vLLM: You can specify the allowed local media path via `--allowed-local-media-path` when launching the API server/engine, @@ -560,23 +586,22 @@ Then, you can use the OpenAI client as follows: ## Use video url in the payload chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this video?" - }, - { - "type": "video_url", - "video_url": { - "url": video_url + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this video?", }, - "uuid": video_url # Optional - }, - ], - }], + { + "type": "video_url", + "video_url": {"url": video_url}, + "uuid": video_url, # Optional + }, + ], + } + ], model=model, max_completion_tokens=64, ) @@ -585,7 +610,7 @@ Then, you can use the OpenAI client as follows: print("Chat completion output from image url:", result) ``` -Full example: +Full example: [examples/online_serving/openai_chat_completion_client_for_multimodal.py](../../examples/online_serving/openai_chat_completion_client_for_multimodal.py) !!! note By default, the timeout for fetching videos through HTTP URL is `30` seconds. @@ -652,23 +677,25 @@ Then, you can use the OpenAI client as follows: audio_base64 = encode_base64_content_from_url(audio_url) chat_completion_from_base64 = client.chat.completions.create( - messages=[{ - "role": "user", - "content": [ - { - "type": "text", - "text": "What's in this audio?" 
- }, - { - "type": "input_audio", - "input_audio": { - "data": audio_base64, - "format": "wav" + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?", }, - "uuid": audio_url # Optional - }, - ], - }], + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": "wav", + }, + "uuid": audio_url, # Optional + }, + ], + }, + ], model=model, max_completion_tokens=64, ) @@ -683,22 +710,22 @@ Alternatively, you can pass `audio_url`, which is the audio counterpart of `imag ```python chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": "user", - "content": [ - { - "type": "text", - "text": "What's in this audio?" - }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?", }, - "uuid": audio_url # Optional - }, - ], - }], + { + "type": "audio_url", + "audio_url": {"url": audio_url}, + "uuid": audio_url, # Optional + }, + ], + } + ], model=model, max_completion_tokens=64, ) @@ -707,7 +734,7 @@ Alternatively, you can pass `audio_url`, which is the audio counterpart of `imag print("Chat completion output from audio url:", result) ``` -Full example: +Full example: [examples/online_serving/openai_chat_completion_client_for_multimodal.py](../../examples/online_serving/openai_chat_completion_client_for_multimodal.py) !!! note By default, the timeout for fetching audios through HTTP URL is `10` seconds. @@ -720,7 +747,13 @@ Full example: +[examples/offline_inference/prompt_embed_inference.py](../../examples/offline_inference/prompt_embed_inference.py) ## Online Serving -Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package. +Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package and are enabled by the `--enable-prompt-embeds` flag in `vllm serve`. When a mixture of `'prompt_embeds'` and `'prompt'` inputs are provided in a single request, the prompt embeds are always returned first. Prompt embeddings are passed in as base64 encoded torch tensors. +!!! warning + The vLLM engine may crash if incorrect shape of embeddings is passed. + Only enable this flag for trusted users! + ### Transformers Inputs via OpenAI Client First, launch the OpenAI-compatible server: @@ -37,4 +41,4 @@ vllm serve meta-llama/Llama-3.2-1B-Instruct --runner generate \ Then, you can use the OpenAI client as follows: - +[examples/online_serving/prompt_embed_inference_with_openai_client.py](../../examples/online_serving/prompt_embed_inference_with_openai_client.py) diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md index 4c8377871e14..74f005c496ee 100644 --- a/docs/features/quantization/README.md +++ b/docs/features/quantization/README.md @@ -64,4 +64,4 @@ th:not(:first-child) { !!! note This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. - For the most up-to-date information on hardware support and quantization methods, please refer to or consult with the vLLM development team. 
+ For the most up-to-date information on hardware support and quantization methods, please refer to [vllm/model_executor/layers/quantization](../../../vllm/model_executor/layers/quantization) or consult with the vLLM development team. diff --git a/docs/features/quantization/auto_awq.md b/docs/features/quantization/auto_awq.md index fc998387d29a..e77e8b5a1f41 100644 --- a/docs/features/quantization/auto_awq.md +++ b/docs/features/quantization/auto_awq.md @@ -1,5 +1,9 @@ # AutoAWQ +> ⚠️ **Warning:** + The `AutoAWQ` library is deprecated. This functionality has been adopted by the vLLM project in [`llm-compressor`](https://github.com/vllm-project/llm-compressor/tree/main/examples/awq). + For the recommended quantization workflow, please see the AWQ examples in [`llm-compressor`](https://github.com/vllm-project/llm-compressor/tree/main/examples/awq). For more details on the deprecation, refer to the original [AutoAWQ repository](https://github.com/casper-hansen/AutoAWQ). + To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github.com/casper-hansen/AutoAWQ). Quantization reduces the model's precision from BF16/FP16 to INT4 which effectively reduces the total model memory footprint. The main benefits are lower latency and memory usage. @@ -18,13 +22,15 @@ After installing AutoAWQ, you are ready to quantize a model. Please refer to the from awq import AutoAWQForCausalLM from transformers import AutoTokenizer - model_path = 'mistralai/Mistral-7B-Instruct-v0.2' - quant_path = 'mistral-instruct-v0.2-awq' - quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } + model_path = "mistralai/Mistral-7B-Instruct-v0.2" + quant_path = "mistral-instruct-v0.2-awq" + quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"} # Load model model = AutoAWQForCausalLM.from_pretrained( - model_path, **{"low_cpu_mem_usage": True, "use_cache": False} + model_path, + low_cpu_mem_usage=True, + use_cache=False, ) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) diff --git a/docs/features/quantization/auto_round.md b/docs/features/quantization/auto_round.md index ac766d5e2922..9c14f362b663 100644 --- a/docs/features/quantization/auto_round.md +++ b/docs/features/quantization/auto_round.md @@ -58,7 +58,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from auto_round import AutoRound model_name = "Qwen/Qwen3-0.6B" -model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto") +model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_name) bits, group_size, sym = 4, 128, True diff --git a/docs/features/quantization/bitblas.md b/docs/features/quantization/bitblas.md index 53b689ad53ff..c3a127657622 100644 --- a/docs/features/quantization/bitblas.md +++ b/docs/features/quantization/bitblas.md @@ -34,7 +34,7 @@ llm = LLM( model=model_id, dtype=torch.bfloat16, trust_remote_code=True, - quantization="bitblas" + quantization="bitblas", ) ``` @@ -53,6 +53,6 @@ llm = LLM( dtype=torch.float16, trust_remote_code=True, quantization="bitblas", - max_model_len=1024 + max_model_len=1024, ) ``` diff --git a/docs/features/quantization/bnb.md b/docs/features/quantization/bnb.md index 3b15a6072d47..2348c7739c06 100644 --- a/docs/features/quantization/bnb.md +++ b/docs/features/quantization/bnb.md @@ -27,7 +27,7 @@ model_id = "unsloth/tinyllama-bnb-4bit" llm = LLM( model=model_id, dtype=torch.bfloat16, - trust_remote_code=True + 
trust_remote_code=True, ) ``` @@ -43,7 +43,7 @@ llm = LLM( model=model_id, dtype=torch.bfloat16, trust_remote_code=True, - quantization="bitsandbytes" + quantization="bitsandbytes", ) ``` diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md index 834c03cbe05b..0c5111fb8af0 100644 --- a/docs/features/quantization/fp8.md +++ b/docs/features/quantization/fp8.md @@ -41,7 +41,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, device_map="auto", torch_dtype="auto", + MODEL_ID, + device_map="auto", + dtype="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) ``` @@ -63,7 +65,10 @@ Since simple RTN does not require data for weight quantization and the activatio # Configure the simple PTQ quantization recipe = QuantizationModifier( - targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]) + targets="Linear", + scheme="FP8_DYNAMIC", + ignore=["lm_head"], + ) # Apply the quantization algorithm. oneshot(model=model, recipe=recipe) diff --git a/docs/features/quantization/gguf.md b/docs/features/quantization/gguf.md index 2a1c3bdd775f..2a731e9b7e03 100644 --- a/docs/features/quantization/gguf.md +++ b/docs/features/quantization/gguf.md @@ -47,15 +47,15 @@ You can also use the GGUF model directly through the LLM entrypoint: conversation = [ { "role": "system", - "content": "You are a helpful assistant" + "content": "You are a helpful assistant", }, { "role": "user", - "content": "Hello" + "content": "Hello", }, { "role": "assistant", - "content": "Hello! How can I assist you today?" + "content": "Hello! How can I assist you today?", }, { "role": "user", @@ -67,8 +67,10 @@ You can also use the GGUF model directly through the LLM entrypoint: sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. - llm = LLM(model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", - tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + llm = LLM( + model="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", + tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + ) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.chat(conversation, sampling_params) diff --git a/docs/features/quantization/gptqmodel.md b/docs/features/quantization/gptqmodel.md index 47cb2d65bae4..f14a931725da 100644 --- a/docs/features/quantization/gptqmodel.md +++ b/docs/features/quantization/gptqmodel.md @@ -40,7 +40,7 @@ Here is an example of how to quantize `meta-llama/Llama-3.2-1B-Instruct`: calibration_dataset = load_dataset( "allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz", - split="train" + split="train", ).select(range(1024))["text"] quant_config = QuantizeConfig(bits=4, group_size=128) diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md index d6fdac7b07f7..035e7ea291f9 100644 --- a/docs/features/quantization/int4.md +++ b/docs/features/quantization/int4.md @@ -39,7 +39,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, device_map="auto", torch_dtype="auto", + MODEL_ID, + device_map="auto", + dtype="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) ``` @@ -166,7 +168,7 @@ The following is an example of an expanded quantization recipe you can tune to y }, ignore=["lm_head"], update_size=NUM_CALIBRATION_SAMPLES, - dampening_frac=0.01 + dampening_frac=0.01, ) ``` diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md index af3650e701ad..ec8a77f74ffe 100644 --- a/docs/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -44,7 +44,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, device_map="auto", torch_dtype="auto", + MODEL_ID, + device_map="auto", + dtype="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) ``` diff --git a/docs/features/quantization/modelopt.md b/docs/features/quantization/modelopt.md index 39ae03b1bdac..c48ccb719a79 100644 --- a/docs/features/quantization/modelopt.md +++ b/docs/features/quantization/modelopt.md @@ -56,9 +56,9 @@ The quantized checkpoint can then be deployed with vLLM. 
As an example, the foll from vllm import LLM, SamplingParams def main(): - model_id = "nvidia/Llama-3.1-8B-Instruct-FP8" - # Ensure you specify quantization='modelopt' when loading the modelopt checkpoint + + # Ensure you specify quantization="modelopt" when loading the modelopt checkpoint llm = LLM(model=model_id, quantization="modelopt", trust_remote_code=True) sampling_params = SamplingParams(temperature=0.8, top_p=0.9) diff --git a/docs/features/quantization/quantized_kvcache.md b/docs/features/quantization/quantized_kvcache.md index b2b417309e92..56cf057678be 100644 --- a/docs/features/quantization/quantized_kvcache.md +++ b/docs/features/quantization/quantized_kvcache.md @@ -41,9 +41,11 @@ Here is an example of how to enable FP8 quantization: from vllm import LLM, SamplingParams sampling_params = SamplingParams(temperature=0.7, top_p=0.8) - llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", - kv_cache_dtype="fp8", - calculate_kv_scales=True) + llm = LLM( + model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + calculate_kv_scales=True, + ) prompt = "London is the capital of" out = llm.generate(prompt, sampling_params)[0].outputs[0].text print(out) @@ -80,7 +82,7 @@ Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models # Select model and load it MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" - model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto") + model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) # Select calibration dataset diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index 85b7d8ec84ed..385e3bbb8712 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -48,7 +48,9 @@ to fetch model and tokenizer. MAX_SEQ_LEN = 512 model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, device_map="auto", torch_dtype="auto", + MODEL_ID, + device_map="auto", + dtype="auto", ) model.eval() @@ -75,10 +77,18 @@ to [Adding Calibration Datasets](https://quark.docs.amd.com/latest/pytorch/calib dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") text_data = dataset["text"][:NUM_CALIBRATION_DATA] - tokenized_outputs = tokenizer(text_data, return_tensors="pt", - padding=True, truncation=True, max_length=MAX_SEQ_LEN) - calib_dataloader = DataLoader(tokenized_outputs['input_ids'], - batch_size=BATCH_SIZE, drop_last=True) + tokenized_outputs = tokenizer( + text_data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=MAX_SEQ_LEN, + ) + calib_dataloader = DataLoader( + tokenized_outputs['input_ids'], + batch_size=BATCH_SIZE, + drop_last=True, + ) ``` ### 3. Set the Quantization Configuration @@ -103,26 +113,32 @@ kv-cache and the quantization algorithm is AutoSmoothQuant. load_quant_algo_config_from_file) # Define fp8/per-tensor/static spec. - FP8_PER_TENSOR_SPEC = FP8E4M3PerTensorSpec(observer_method="min_max", - is_dynamic=False).to_quantization_spec() + FP8_PER_TENSOR_SPEC = FP8E4M3PerTensorSpec( + observer_method="min_max", + is_dynamic=False, + ).to_quantization_spec() # Define global quantization config, input tensors and weight apply FP8_PER_TENSOR_SPEC. 
- global_quant_config = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC, - weight=FP8_PER_TENSOR_SPEC) + global_quant_config = QuantizationConfig( + input_tensors=FP8_PER_TENSOR_SPEC, + weight=FP8_PER_TENSOR_SPEC, + ) # Define quantization config for kv-cache layers, output tensors apply FP8_PER_TENSOR_SPEC. KV_CACHE_SPEC = FP8_PER_TENSOR_SPEC kv_cache_layer_names_for_llama = ["*k_proj", "*v_proj"] - kv_cache_quant_config = {name : - QuantizationConfig(input_tensors=global_quant_config.input_tensors, - weight=global_quant_config.weight, - output_tensors=KV_CACHE_SPEC) - for name in kv_cache_layer_names_for_llama} + kv_cache_quant_config = { + name: QuantizationConfig( + input_tensors=global_quant_config.input_tensors, + weight=global_quant_config.weight, + output_tensors=KV_CACHE_SPEC, + ) + for name in kv_cache_layer_names_for_llama + } layer_quant_config = kv_cache_quant_config.copy() # Define algorithm config by config file. - LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE = - 'examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json' + LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE = "examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json" algo_config = load_quant_algo_config_from_file(LLAMA_AUTOSMOOTHQUANT_CONFIG_FILE) EXCLUDE_LAYERS = ["lm_head"] @@ -131,7 +147,8 @@ kv-cache and the quantization algorithm is AutoSmoothQuant. layer_quant_config=layer_quant_config, kv_cache_quant_config=kv_cache_quant_config, exclude=EXCLUDE_LAYERS, - algo_config=algo_config) + algo_config=algo_config, + ) ``` ### 4. Quantize the Model and Export @@ -165,8 +182,11 @@ for more exporting format details. EXPORT_DIR = MODEL_ID.split("/")[1] + "-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant" exporter = ModelExporter(config=export_config, export_dir=EXPORT_DIR) with torch.no_grad(): - exporter.export_safetensors_model(freezed_model, - quant_config=quant_config, tokenizer=tokenizer) + exporter.export_safetensors_model( + freezed_model, + quant_config=quant_config, + tokenizer=tokenizer, + ) ``` ### 5. Evaluation in vLLM @@ -189,8 +209,11 @@ Now, you can load and run the Quark quantized model directly through the LLM ent sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. - llm = LLM(model="Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant", - kv_cache_dtype='fp8',quantization='quark') + llm = LLM( + model="Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant", + kv_cache_dtype="fp8", + quantization="quark", + ) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/docs/features/quantization/torchao.md b/docs/features/quantization/torchao.md index 693244599701..b95b560882bb 100644 --- a/docs/features/quantization/torchao.md +++ b/docs/features/quantization/torchao.md @@ -27,7 +27,7 @@ You can quantize your own huggingface model with torchao, e.g. 
[transformers](ht quantization_config = TorchAoConfig(Int8WeightOnlyConfig()) quantized_model = AutoModelForCausalLM.from_pretrained( model_name, - torch_dtype="auto", + dtype="auto", device_map="auto", quantization_config=quantization_config ) diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 85681669dfb2..302d1161c902 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -11,6 +11,9 @@ vLLM currently supports the following reasoning models: | Model Series | Parser Name | Structured Output Support | Tool Calling | |--------------|-------------|------------------|-------------| | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ | +| [DeepSeek-V3.1](https://huggingface.co/collections/deepseek-ai/deepseek-v31-68a491bed32bd77e7fca048f) | `deepseek_v3` | `json`, `regex` | ❌ | +| [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ | +| [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | | [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ | @@ -18,8 +21,9 @@ vLLM currently supports the following reasoning models: | [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ | !!! note - IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. + IBM Granite 3.2 and DeepSeek-V3.1 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. The reasoning feature for the Qwen3 series is enabled by default. To disable it, you must pass `enable_thinking=False` in your `chat_template_kwargs`. + DeepSeek-V3.1 tool calling is supported in non-thinking mode. 
## Quickstart @@ -115,9 +119,11 @@ OpenAI Python client library does not officially support `reasoning_content` att # For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` # For Qwen3 series, if you want to disable thinking in reasoning mode, add: # extra_body={"chat_template_kwargs": {"enable_thinking": False}} - stream = client.chat.completions.create(model=model, - messages=messages, - stream=True) + stream = client.chat.completions.create( + model=model, + messages=messages, + stream=True, + ) print("client: Start streaming chat completions...") printed_reasoning_content = False @@ -157,27 +163,29 @@ The reasoning content is also available when both tool calling and the reasoning client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy") - tools = [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"}, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} - }, - "required": ["location", "unit"] - } + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + } + }, } - }] + ] response = client.chat.completions.create( model=client.models.list().data[0].id, messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}], tools=tools, - tool_choice="auto" + tool_choice="auto", ) print(response) @@ -188,7 +196,7 @@ The reasoning content is also available when both tool calling and the reasoning print(f"Arguments: {tool_call.arguments}") ``` -For more examples, please refer to . +For more examples, please refer to [examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py](../../examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py). ## Limitations @@ -196,7 +204,7 @@ For more examples, please refer to . +You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reasoning_parser.py](../../vllm/reasoning/deepseek_r1_reasoning_parser.py). ??? code @@ -223,7 +231,7 @@ You can add a new `ReasoningParser` similar to Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Instance method that should be implemented for extracting reasoning from an incomplete response; for use when handling reasoning calls and @@ -233,8 +241,10 @@ You can add a new `ReasoningParser` similar to tuple[Optional[str], Optional[str]]: + self, + model_output: str, + request: ChatCompletionRequest | ResponsesRequest, + ) -> tuple[str | None, str | None]: """ Extract reasoning content from a complete model-generated string. @@ -254,7 +264,7 @@ You can add a new `ReasoningParser` similar to . +Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in [vllm/reasoning/deepseek_r1_reasoning_parser.py](../../vllm/reasoning/deepseek_r1_reasoning_parser.py). ??? 
code @@ -272,10 +282,10 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner @classmethod def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner: - return cls(start_token_id=tokenizer.encode( - "", add_special_tokens=False)[0], - end_token_id=tokenizer.encode("", - add_special_tokens=False)[0]) + return cls( + start_token_id=tokenizer.encode("", add_special_tokens=False)[0], + end_token_id=tokenizer.encode("", add_special_tokens=False)[0], + ) def is_reasoning_end(self, input_ids: list[int]) -> bool: return self.end_token_id in input_ids diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index 25c308a6ff20..ab72c7d97b7a 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -3,7 +3,7 @@ !!! warning Please note that speculative decoding in vLLM is not yet optimized and does not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. - The work to optimize it is ongoing and can be followed here: + The work to optimize it is ongoing and can be followed here: !!! warning Currently, speculative decoding in vLLM is not compatible with pipeline parallelism. @@ -183,7 +183,7 @@ A variety of speculative models of this type are available on HF hub: ## Speculating using EAGLE based draft models The following code configures vLLM to use speculative decoding where proposals are generated by -an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request level acceptance rate, can be found [here](gh-file:examples/offline_inference/eagle.py). +an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request level acceptance rate, can be found [here](../../examples/offline_inference/spec_decode.py). ??? code @@ -218,8 +218,8 @@ an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https A few important things to consider when using the EAGLE based draft models: 1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) should - be able to be loaded and used directly by vLLM after . - If you are using vllm version before , please use the + be able to be loaded and used directly by vLLM after . + If you are using vllm version before , please use the [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model, and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue. @@ -229,7 +229,7 @@ A few important things to consider when using the EAGLE based draft models: 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is reported in the reference implementation [here](https://github.com/SafeAILab/EAGLE). This issue is under - investigation and tracked here: . + investigation and tracked here: . 4. When using EAGLE-3 based draft model, option "method" must be set to "eagle3". That is, to specify `"method": "eagle3"` in `speculative_config`. @@ -267,7 +267,7 @@ speculative decoding, breaking down the guarantees into three key areas: > distribution. 
[View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252) > - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling > without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, - > provides a lossless guarantee. Almost all of the tests in . + > provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](../../tests/spec_decode/e2e). > verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291) 3. **vLLM Logprob Stability** @@ -289,4 +289,4 @@ For mitigation strategies, please refer to the FAQ entry *Can the output of a pr - [A Hacker's Guide to Speculative Decoding in vLLM](https://www.youtube.com/watch?v=9wNAgpX6z_4) - [What is Lookahead Scheduling in vLLM?](https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a) - [Information on batch expansion](https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8) -- [Dynamic speculative decoding](gh-issue:4565) +- [Dynamic speculative decoding](https://github.com/vllm-project/vllm/issues/4565) diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 901d87e7ed3d..9e1da37ca962 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -298,7 +298,7 @@ Step #2: explanation="Next, let's isolate 'x' by dividing both sides of the equa Answer: x = -29/8 ``` -An example of using `structural_tag` can be found here: +An example of using `structural_tag` can be found here: [examples/online_serving/structured_outputs](../../examples/online_serving/structured_outputs) ## Offline Inference diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 6a0bcfac66d0..228619343c9d 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -27,27 +27,29 @@ Next, make a request that triggers the model to use the available tools: return f"Getting the weather for {location} in {unit}..." tool_functions = {"get_weather": get_weather} - tools = [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"}, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City and state, e.g., 'San Francisco, CA'"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} + }, + "required": ["location", "unit"], }, - "required": ["location", "unit"] - } - } - }] + }, + }, + ] response = client.chat.completions.create( model=client.models.list().data[0].id, messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}], tools=tools, - tool_choice="auto" + tool_choice="auto", ) tool_call = response.choices[0].message.tool_calls[0].function @@ -145,16 +147,23 @@ Supported models: Known issues: 1. 
Mistral 7B struggles to generate parallel tool calls correctly. -2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is +2. **For Transformers tokenization backend only**: Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is much shorter than what vLLM generates. Since an exception is thrown when this condition is not met, the following additional chat templates are provided: - * - this is the "official" Mistral chat template, but tweaked so that + * [examples/tool_chat_template_mistral.jinja](../../examples/tool_chat_template_mistral.jinja) - this is the "official" Mistral chat template, but tweaked so that it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) - * - this is a "better" version that adds a tool-use system prompt + * [examples/tool_chat_template_mistral_parallel.jinja](../../examples/tool_chat_template_mistral_parallel.jinja) - this is a "better" version that adds a tool-use system prompt when tools are provided, that results in much better reliability when working with parallel tool calling. -Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` +Recommended flags: + +1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend: + + `--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral` + +2. To use the default Transformers tokenization backend: + `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` ### Llama Models (`llama3_json`) @@ -178,16 +187,16 @@ Known issues: VLLM provides two JSON-based chat templates for Llama 3.1 and 3.2: -* - this is the "official" chat template for the Llama 3.1 +* [examples/tool_chat_template_llama3.1_json.jinja](../../examples/tool_chat_template_llama3.1_json.jinja) - this is the "official" chat template for the Llama 3.1 models, but tweaked so that it works better with vLLM. -* - this extends upon the Llama 3.1 chat template by adding support for +* [examples/tool_chat_template_llama3.2_json.jinja](../../examples/tool_chat_template_llama3.2_json.jinja) - this extends upon the Llama 3.1 chat template by adding support for images. Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}` VLLM also provides a pythonic and JSON-based chat template for Llama 4, but pythonic tool calling is recommended: -* - this is based on the [official chat template](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/) for the Llama 4 models. +* [examples/tool_chat_template_llama4_pythonic.jinja](../../examples/tool_chat_template_llama4_pythonic.jinja) - this is based on the [official chat template](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/) for the Llama 4 models. For Llama 4 model, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`. @@ -203,7 +212,7 @@ Supported models: Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` - : this is a modified chat template from the original on Hugging Face. Parallel function calls are supported. + [examples/tool_chat_template_granite.jinja](../../examples/tool_chat_template_granite.jinja): this is a modified chat template from the original on Hugging Face. Parallel function calls are supported. 
* `ibm-granite/granite-3.1-8b-instruct` @@ -215,7 +224,7 @@ Supported models: Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` - : this is a modified chat template from the original on Hugging Face, which is not vLLM-compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. + [examples/tool_chat_template_granite_20b_fc.jinja](../../examples/tool_chat_template_granite_20b_fc.jinja): this is a modified chat template from the original on Hugging Face, which is not vLLM-compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. ### InternLM Models (`internlm`) @@ -273,8 +282,8 @@ Flags: `--tool-call-parser hermes` Supported models: -* `MiniMaxAi/MiniMax-M1-40k` (use with ) -* `MiniMaxAi/MiniMax-M1-80k` (use with ) +* `MiniMaxAi/MiniMax-M1-40k` (use with [examples/tool_chat_template_minimax_m1.jinja](../../examples/tool_chat_template_minimax_m1.jinja)) +* `MiniMaxAi/MiniMax-M1-80k` (use with [examples/tool_chat_template_minimax_m1.jinja](../../examples/tool_chat_template_minimax_m1.jinja)) Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax_m1.jinja` @@ -282,8 +291,8 @@ Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_m Supported models: -* `deepseek-ai/DeepSeek-V3-0324` (use with ) -* `deepseek-ai/DeepSeek-R1-0528` (use with ) +* `deepseek-ai/DeepSeek-V3-0324` (use with [examples/tool_chat_template_deepseekv3.jinja](../../examples/tool_chat_template_deepseekv3.jinja)) +* `deepseek-ai/DeepSeek-R1-0528` (use with [examples/tool_chat_template_deepseekr1.jinja](../../examples/tool_chat_template_deepseekr1.jinja)) Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}` @@ -291,7 +300,7 @@ Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}` Supported models: -* `deepseek-ai/DeepSeek-V3.1` (use with ) +* `deepseek-ai/DeepSeek-V3.1` (use with [examples/tool_chat_template_deepseekv31.jinja](../../examples/tool_chat_template_deepseekv31.jinja)) Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}` @@ -343,6 +352,16 @@ Supported models: Flags: `--tool-call-parser qwen3_xml` +### Olmo 3 Models (`olmo3`) + +Olmo 3 models output tool calls in a format that is very similar to the one expected by the `pythonic` parser (see below), with a few differences. Each tool call is a pythonic string, but the parallel tool calls are newline-delimited, and the calls are wrapped within XML tags as `..`. In addition, the parser also allows JSON boolean and null literals (`true`, `false`, and `null`) in addition to the pythonic ones (`True`, `False`, and `None`). + +Supported models: + +* TODO (will be updated after Olmo 3 release) + +Flags: `--tool-call-parser olmo3` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. 
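As a concrete illustration (a sketch only, not vLLM's parser implementation), a pythonic tool-call response is simply a Python list expression, so it can be inspected with the standard `ast` module without executing it; the `get_weather` call below is a hypothetical example output:

```python
import ast

# Hypothetical model output in the pythonic tool-call format: a Python list
# of function calls with keyword arguments.
model_output = '[get_weather(location="San Francisco, CA", unit="celsius")]'

# Parse the string as a Python expression and walk the call nodes to recover
# each tool name and its arguments.
tree = ast.parse(model_output, mode="eval")
for call in tree.body.elts:
    name = call.func.id
    arguments = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    print(name, arguments)
    # get_weather {'location': 'San Francisco, CA', 'unit': 'celsius'}
```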
@@ -360,12 +379,12 @@ Limitations: Example supported models: -* `meta-llama/Llama-3.2-1B-Instruct` ⚠️ (use with ) -* `meta-llama/Llama-3.2-3B-Instruct` ⚠️ (use with ) -* `Team-ACE/ToolACE-8B` (use with ) -* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with ) -* `meta-llama/Llama-4-Scout-17B-16E-Instruct` ⚠️ (use with ) -* `meta-llama/Llama-4-Maverick-17B-128E-Instruct` ⚠️ (use with ) +* `meta-llama/Llama-3.2-1B-Instruct` ⚠️ (use with [examples/tool_chat_template_llama3.2_pythonic.jinja](../../examples/tool_chat_template_llama3.2_pythonic.jinja)) +* `meta-llama/Llama-3.2-3B-Instruct` ⚠️ (use with [examples/tool_chat_template_llama3.2_pythonic.jinja](../../examples/tool_chat_template_llama3.2_pythonic.jinja)) +* `Team-ACE/ToolACE-8B` (use with [examples/tool_chat_template_toolace.jinja](../../examples/tool_chat_template_toolace.jinja)) +* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with [examples/tool_chat_template_toolace.jinja](../../examples/tool_chat_template_toolace.jinja)) +* `meta-llama/Llama-4-Scout-17B-16E-Instruct` ⚠️ (use with [examples/tool_chat_template_llama4_pythonic.jinja](../../examples/tool_chat_template_llama4_pythonic.jinja)) +* `meta-llama/Llama-4-Maverick-17B-128E-Instruct` ⚠️ (use with [examples/tool_chat_template_llama4_pythonic.jinja](../../examples/tool_chat_template_llama4_pythonic.jinja)) Flags: `--tool-call-parser pythonic --chat-template {see_above}` @@ -374,7 +393,7 @@ Flags: `--tool-call-parser pythonic --chat-template {see_above}` ## How to Write a Tool Parser Plugin -A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in . +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in [vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py](../../vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py). Here is a summary of a plugin file: @@ -395,8 +414,7 @@ Here is a summary of a plugin file: # adjust request. e.g.: set skip special tokens # to False for tool call output. - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: return request # implement the tool call parse for stream call @@ -409,7 +427,7 @@ Here is a summary of a plugin file: current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: return delta # implement the tool parse for non-stream call diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu.apple.inc.md similarity index 100% rename from docs/getting_started/installation/cpu/apple.inc.md rename to docs/getting_started/installation/cpu.apple.inc.md diff --git a/docs/getting_started/installation/cpu/arm.inc.md b/docs/getting_started/installation/cpu.arm.inc.md similarity index 56% rename from docs/getting_started/installation/cpu/arm.inc.md rename to docs/getting_started/installation/cpu.arm.inc.md index e45baa0aa493..9cae9ed1a212 100644 --- a/docs/getting_started/installation/cpu/arm.inc.md +++ b/docs/getting_started/installation/cpu.arm.inc.md @@ -23,7 +23,46 @@ ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. 
# --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] ---8<-- "docs/getting_started/installation/cpu/build.inc.md" +First, install the recommended compiler. We recommend using `gcc/g++ >= 12.3.0` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: + +```bash +sudo apt-get update -y +sudo apt-get install -y --no-install-recommends ccache git curl wget ca-certificates gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof +sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +``` + +Second, clone the vLLM project: + +```bash +git clone https://github.com/vllm-project/vllm.git vllm_source +cd vllm_source +``` + +Third, install required dependencies: + +```bash +uv pip install -r requirements/cpu-build.txt --torch-backend cpu +uv pip install -r requirements/cpu.txt --torch-backend cpu +``` + +??? console "pip" + ```bash + pip install --upgrade pip + pip install -v -r requirements/cpu-build.txt --extra-index-url https://download.pytorch.org/whl/cpu + pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu + ``` + +Finally, build and install vLLM: + +```bash +VLLM_TARGET_DEVICE=cpu uv pip install . --no-build-isolation +``` + +If you want to develop vLLM, install it in editable mode instead. + +```bash +VLLM_TARGET_DEVICE=cpu uv pip install -e . --no-build-isolation +``` Testing has been conducted on AWS Graviton3 instances for compatibility. diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index f8b4f75308df..747035d38e3b 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -4,39 +4,39 @@ vLLM is a Python library that supports the following CPU variants. Select your C === "Intel/AMD x86" - --8<-- "docs/getting_started/installation/cpu/x86.inc.md:installation" + --8<-- "docs/getting_started/installation/cpu.x86.inc.md:installation" === "ARM AArch64" - --8<-- "docs/getting_started/installation/cpu/arm.inc.md:installation" + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:installation" === "Apple silicon" - --8<-- "docs/getting_started/installation/cpu/apple.inc.md:installation" + --8<-- "docs/getting_started/installation/cpu.apple.inc.md:installation" === "IBM Z (S390X)" - --8<-- "docs/getting_started/installation/cpu/s390x.inc.md:installation" + --8<-- "docs/getting_started/installation/cpu.s390x.inc.md:installation" ## Requirements -- Python: 3.9 -- 3.12 +- Python: 3.10 -- 3.13 === "Intel/AMD x86" - --8<-- "docs/getting_started/installation/cpu/x86.inc.md:requirements" + --8<-- "docs/getting_started/installation/cpu.x86.inc.md:requirements" === "ARM AArch64" - --8<-- "docs/getting_started/installation/cpu/arm.inc.md:requirements" + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:requirements" === "Apple silicon" - --8<-- "docs/getting_started/installation/cpu/apple.inc.md:requirements" + --8<-- "docs/getting_started/installation/cpu.apple.inc.md:requirements" === "IBM Z (S390X)" - --8<-- "docs/getting_started/installation/cpu/s390x.inc.md:requirements" + --8<-- "docs/getting_started/installation/cpu.s390x.inc.md:requirements" ## Set up using Python @@ -52,19 +52,19 @@ Currently, there are no pre-built CPU wheels. 
=== "Intel/AMD x86" - --8<-- "docs/getting_started/installation/cpu/x86.inc.md:build-wheel-from-source" + --8<-- "docs/getting_started/installation/cpu.x86.inc.md:build-wheel-from-source" === "ARM AArch64" - --8<-- "docs/getting_started/installation/cpu/arm.inc.md:build-wheel-from-source" + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:build-wheel-from-source" === "Apple silicon" - --8<-- "docs/getting_started/installation/cpu/apple.inc.md:build-wheel-from-source" + --8<-- "docs/getting_started/installation/cpu.apple.inc.md:build-wheel-from-source" === "IBM Z (s390x)" - --8<-- "docs/getting_started/installation/cpu/s390x.inc.md:build-wheel-from-source" + --8<-- "docs/getting_started/installation/cpu.s390x.inc.md:build-wheel-from-source" ## Set up using Docker @@ -72,24 +72,24 @@ Currently, there are no pre-built CPU wheels. === "Intel/AMD x86" - --8<-- "docs/getting_started/installation/cpu/x86.inc.md:pre-built-images" + --8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-images" ### Build image from source === "Intel/AMD x86" - --8<-- "docs/getting_started/installation/cpu/x86.inc.md:build-image-from-source" + --8<-- "docs/getting_started/installation/cpu.x86.inc.md:build-image-from-source" === "ARM AArch64" - --8<-- "docs/getting_started/installation/cpu/arm.inc.md:build-image-from-source" + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:build-image-from-source" === "Apple silicon" - --8<-- "docs/getting_started/installation/cpu/arm.inc.md:build-image-from-source" + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:build-image-from-source" === "IBM Z (S390X)" - --8<-- "docs/getting_started/installation/cpu/s390x.inc.md:build-image-from-source" + --8<-- "docs/getting_started/installation/cpu.s390x.inc.md:build-image-from-source" ## Related runtime environment variables diff --git a/docs/getting_started/installation/cpu/s390x.inc.md b/docs/getting_started/installation/cpu.s390x.inc.md similarity index 100% rename from docs/getting_started/installation/cpu/s390x.inc.md rename to docs/getting_started/installation/cpu.s390x.inc.md diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu.x86.inc.md similarity index 100% rename from docs/getting_started/installation/cpu/x86.inc.md rename to docs/getting_started/installation/cpu.x86.inc.md diff --git a/docs/getting_started/installation/cpu/build.inc.md b/docs/getting_started/installation/cpu/build.inc.md deleted file mode 100644 index 4bd4d39a6f80..000000000000 --- a/docs/getting_started/installation/cpu/build.inc.md +++ /dev/null @@ -1,45 +0,0 @@ -First, install the recommended compiler. We recommend using `gcc/g++ >= 12.3.0` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: - -```bash -sudo apt-get update -y -sudo apt-get install -y --no-install-recommends ccache git curl wget ca-certificates gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof -sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 -``` - -Second, clone the vLLM project: - -```bash -git clone https://github.com/vllm-project/vllm.git vllm_source -cd vllm_source -``` - -Third, install required dependencies: - -```bash -uv pip install -r requirements/cpu-build.txt --torch-backend cpu -uv pip install -r requirements/cpu.txt --torch-backend cpu -``` - -??? 
console "pip" - ```bash - pip install --upgrade pip - pip install -v -r requirements/cpu-build.txt --extra-index-url https://download.pytorch.org/whl/cpu - pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu - ``` - -Finally, build and install vLLM: - -```bash -VLLM_TARGET_DEVICE=cpu python setup.py install -``` - -If you want to develop vLLM, install it in editable mode instead. - -```bash -VLLM_TARGET_DEVICE=cpu python setup.py develop -``` - -!!! note - If you are building vLLM from source and not using the pre-built images, remember to set `LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD"` on x86 machines before running vLLM. - -# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/google_tpu.md b/docs/getting_started/installation/google_tpu.md index 6f09babb3aba..0f8c5bccd4b9 100644 --- a/docs/getting_started/installation/google_tpu.md +++ b/docs/getting_started/installation/google_tpu.md @@ -153,11 +153,11 @@ VLLM_TARGET_DEVICE="tpu" python -m pip install -e . ### Pre-built images -See [deployment-docker-pre-built-image][deployment-docker-pre-built-image] for instructions on using the official Docker image, making sure to substitute the image name `vllm/vllm-openai` with `vllm/vllm-tpu`. +See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image, making sure to substitute the image name `vllm/vllm-openai` with `vllm/vllm-tpu`. ### Build image from source -You can use to build a Docker image with TPU support. +You can use [docker/Dockerfile.tpu](../../../docker/Dockerfile.tpu) to build a Docker image with TPU support. ```bash docker build -f docker/Dockerfile.tpu -t vllm-tpu . diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu.cuda.inc.md similarity index 94% rename from docs/getting_started/installation/gpu/cuda.inc.md rename to docs/getting_started/installation/gpu.cuda.inc.md index 9e64c6f2540a..b2d0d64a2d35 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu.cuda.inc.md @@ -11,11 +11,11 @@ vLLM contains pre-compiled C++ and CUDA (12.8) binaries. # --8<-- [start:set-up-using-python] !!! note - PyTorch installed via `conda` will statically link `NCCL` library, which can cause issues when vLLM tries to use `NCCL`. See for more details. + PyTorch installed via `conda` will statically link `NCCL` library, which can cause issues when vLLM tries to use `NCCL`. See for more details. In order to be performant, vLLM has to compile many cuda kernels. The compilation unfortunately introduces binary incompatibility with other CUDA versions and PyTorch versions, even for the same PyTorch version with different building configurations. -Therefore, it is recommended to install vLLM with a **fresh new** environment. If either you have a different CUDA version or you want to use an existing PyTorch installation, you need to build vLLM from source. See [below][build-from-source] for more details. +Therefore, it is recommended to install vLLM with a **fresh new** environment. If either you have a different CUDA version or you want to use an existing PyTorch installation, you need to build vLLM from source. See [below](#build-wheel-from-source) for more details. 
# --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] @@ -44,8 +44,6 @@ export CUDA_VERSION=118 # or 126 uv pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu${CUDA_VERSION}-cp38-abi3-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu${CUDA_VERSION} ``` -[](){ #install-the-latest-code } - #### Install the latest code LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on an x86 platform with CUDA 12 for every commit since `v0.5.3`. @@ -128,11 +126,11 @@ export VLLM_PRECOMPILED_WHEEL_LOCATION=https://wheels.vllm.ai/${VLLM_COMMIT}/vll uv pip install --editable . ``` -You can find more information about vLLM's wheels in [install-the-latest-code][install-the-latest-code]. +You can find more information about vLLM's wheels in [Install the latest code](#install-the-latest-code). !!! note There is a possibility that your source code may have a different commit ID compared to the latest vLLM wheel, which could potentially lead to unknown errors. - It is recommended to use the same commit ID for the source code as the vLLM wheel you have installed. Please refer to [install-the-latest-code][install-the-latest-code] for instructions on how to install a specified wheel. + It is recommended to use the same commit ID for the source code as the vLLM wheel you have installed. Please refer to [Install the latest code](#install-the-latest-code) for instructions on how to install a specified wheel. #### Full build (with compilation) @@ -250,7 +248,7 @@ uv pip install -e . # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] -See [deployment-docker-pre-built-image][deployment-docker-pre-built-image] for instructions on using the official Docker image. +See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image. Another way to access the latest code is to use the docker images: @@ -266,11 +264,11 @@ The latest code can contain bugs and may not be stable. Please use it with cauti # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] -See [deployment-docker-build-image-from-source][deployment-docker-build-image-from-source] for instructions on building the Docker image. +See [Building vLLM's Docker Image from Source](../../deployment/docker.md#building-vllms-docker-image-from-source) for instructions on building the Docker image. # --8<-- [end:build-image-from-source] # --8<-- [start:supported-features] -See [feature-x-hardware][feature-x-hardware] compatibility matrix for feature support information. +See [Feature x Hardware](../../features/README.md#feature-x-hardware) compatibility matrix for feature support information. # --8<-- [end:supported-features] diff --git a/docs/getting_started/installation/gpu.md b/docs/getting_started/installation/gpu.md index e688cefea076..bc7508b29475 100644 --- a/docs/getting_started/installation/gpu.md +++ b/docs/getting_started/installation/gpu.md @@ -4,35 +4,35 @@ vLLM is a Python library that supports the following GPU variants. 
Select your G === "NVIDIA CUDA" - --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:installation" + --8<-- "docs/getting_started/installation/gpu.cuda.inc.md:installation" === "AMD ROCm" - --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:installation" + --8<-- "docs/getting_started/installation/gpu.rocm.inc.md:installation" === "Intel XPU" - --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:installation" + --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:installation" ## Requirements - OS: Linux -- Python: 3.9 -- 3.12 +- Python: 3.10 -- 3.13 !!! note vLLM does not support Windows natively. To run vLLM on Windows, you can use the Windows Subsystem for Linux (WSL) with a compatible Linux distribution, or use some community-maintained forks, e.g. [https://github.com/SystemPanic/vllm-windows](https://github.com/SystemPanic/vllm-windows). === "NVIDIA CUDA" - --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:requirements" + --8<-- "docs/getting_started/installation/gpu.cuda.inc.md:requirements" === "AMD ROCm" - --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:requirements" + --8<-- "docs/getting_started/installation/gpu.rocm.inc.md:requirements" === "Intel XPU" - --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:requirements" + --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:requirements" ## Set up using Python @@ -42,45 +42,43 @@ vLLM is a Python library that supports the following GPU variants. Select your G === "NVIDIA CUDA" - --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:set-up-using-python" + --8<-- "docs/getting_started/installation/gpu.cuda.inc.md:set-up-using-python" === "AMD ROCm" - --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:set-up-using-python" + --8<-- "docs/getting_started/installation/gpu.rocm.inc.md:set-up-using-python" === "Intel XPU" - --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:set-up-using-python" + --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:set-up-using-python" ### Pre-built wheels === "NVIDIA CUDA" - --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:pre-built-wheels" + --8<-- "docs/getting_started/installation/gpu.cuda.inc.md:pre-built-wheels" === "AMD ROCm" - --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:pre-built-wheels" + --8<-- "docs/getting_started/installation/gpu.rocm.inc.md:pre-built-wheels" === "Intel XPU" - --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:pre-built-wheels" - -[](){ #build-from-source } + --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:pre-built-wheels" ### Build wheel from source === "NVIDIA CUDA" - --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:build-wheel-from-source" + --8<-- "docs/getting_started/installation/gpu.cuda.inc.md:build-wheel-from-source" === "AMD ROCm" - --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:build-wheel-from-source" + --8<-- "docs/getting_started/installation/gpu.rocm.inc.md:build-wheel-from-source" === "Intel XPU" - --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:build-wheel-from-source" + --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:build-wheel-from-source" ## Set up using Docker @@ -88,40 +86,40 @@ vLLM is a Python library that supports the following GPU variants. 
Select your G === "NVIDIA CUDA" - --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:pre-built-images" + --8<-- "docs/getting_started/installation/gpu.cuda.inc.md:pre-built-images" === "AMD ROCm" - --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:pre-built-images" + --8<-- "docs/getting_started/installation/gpu.rocm.inc.md:pre-built-images" === "Intel XPU" - --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:pre-built-images" + --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:pre-built-images" ### Build image from source === "NVIDIA CUDA" - --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:build-image-from-source" + --8<-- "docs/getting_started/installation/gpu.cuda.inc.md:build-image-from-source" === "AMD ROCm" - --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:build-image-from-source" + --8<-- "docs/getting_started/installation/gpu.rocm.inc.md:build-image-from-source" === "Intel XPU" - --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:build-image-from-source" + --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:build-image-from-source" ## Supported features === "NVIDIA CUDA" - --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:supported-features" + --8<-- "docs/getting_started/installation/gpu.cuda.inc.md:supported-features" === "AMD ROCm" - --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:supported-features" + --8<-- "docs/getting_started/installation/gpu.rocm.inc.md:supported-features" === "Intel XPU" - --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:supported-features" + --8<-- "docs/getting_started/installation/gpu.xpu.inc.md:supported-features" diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu.rocm.inc.md similarity index 91% rename from docs/getting_started/installation/gpu/rocm.inc.md rename to docs/getting_started/installation/gpu.rocm.inc.md index 37c6647929b5..8abc5ac1c5c7 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu.rocm.inc.md @@ -146,7 +146,7 @@ Building the Docker image from source is the recommended way to use vLLM with RO #### (Optional) Build an image with ROCm software stack -Build a docker image from which setup ROCm software stack needed by the vLLM. +Build a docker image from [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base) which setup ROCm software stack needed by the vLLM. **This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.** If you choose to build this rocm_base image yourself, the steps are as follows. @@ -170,7 +170,7 @@ DOCKER_BUILDKIT=1 docker build \ #### Build an image with vLLM -First, build a docker image from and launch a docker container from the image. +First, build a docker image from [docker/Dockerfile.rocm](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm) and launch a docker container from the image. It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: ```bash @@ -181,10 +181,10 @@ It is important that the user kicks off the docker build using buildkit. 
Either } ``` - uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2, in older vLLM branches. +[docker/Dockerfile.rocm](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm) uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2, in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: -- `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using +- `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base) - `ARG_PYTORCH_ROCM_ARCH`: Allows to override the gfx architecture values from the base docker image Their values can be passed in when running `docker build` with `--build-arg` options. @@ -217,6 +217,6 @@ Where the `` is the location where the model is stored, for examp # --8<-- [end:build-image-from-source] # --8<-- [start:supported-features] -See [feature-x-hardware][feature-x-hardware] compatibility matrix for feature support information. +See [Feature x Hardware](../../features/README.md#feature-x-hardware) compatibility matrix for feature support information. # --8<-- [end:supported-features] diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu.xpu.inc.md similarity index 94% rename from docs/getting_started/installation/gpu/xpu.inc.md rename to docs/getting_started/installation/gpu.xpu.inc.md index 2e73ac182569..9156df9db6df 100644 --- a/docs/getting_started/installation/gpu/xpu.inc.md +++ b/docs/getting_started/installation/gpu.xpu.inc.md @@ -75,7 +75,7 @@ vllm serve facebook/opt-13b \ -tp=8 ``` -By default, a ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equals to `parallel_config.world_size`. We recommend properly starting a ray cluster before execution, referring to the helper script. +By default, a ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equals to `parallel_config.world_size`. We recommend properly starting a ray cluster before execution, referring to the [examples/online_serving/run_cluster.sh](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/run_cluster.sh) helper script. # --8<-- [end:supported-features] # --8<-- [start:distributed-backend] diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index 2af26626d207..70a91b7454ce 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -2,51 +2,73 @@ This guide will help you quickly get started with vLLM to perform: -- [Offline batched inference][quickstart-offline] -- [Online serving using OpenAI-compatible server][quickstart-online] +- [Offline batched inference](#offline-batched-inference) +- [Online serving using OpenAI-compatible server](#openai-compatible-server) ## Prerequisites - OS: Linux -- Python: 3.9 -- 3.13 +- Python: 3.10 -- 3.13 ## Installation -If you are using NVIDIA GPUs, you can install vLLM using [pip](https://pypi.org/project/vllm/) directly. 
+=== "NVIDIA CUDA" -It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands: + If you are using NVIDIA GPUs, you can install vLLM using [pip](https://pypi.org/project/vllm/) directly. -```bash -uv venv --python 3.12 --seed -source .venv/bin/activate -uv pip install vllm --torch-backend=auto -``` + It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands: + + ```bash + uv venv --python 3.12 --seed + source .venv/bin/activate + uv pip install vllm --torch-backend=auto + ``` -`uv` can [automatically select the appropriate PyTorch index at runtime](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection) by inspecting the installed CUDA driver version via `--torch-backend=auto` (or `UV_TORCH_BACKEND=auto`). To select a specific backend (e.g., `cu126`), set `--torch-backend=cu126` (or `UV_TORCH_BACKEND=cu126`). + `uv` can [automatically select the appropriate PyTorch index at runtime](https://docs.astral.sh/uv/guides/integration/pytorch/#automatic-backend-selection) by inspecting the installed CUDA driver version via `--torch-backend=auto` (or `UV_TORCH_BACKEND=auto`). To select a specific backend (e.g., `cu126`), set `--torch-backend=cu126` (or `UV_TORCH_BACKEND=cu126`). -Another delightful way is to use `uv run` with `--with [dependency]` option, which allows you to run commands such as `vllm serve` without creating any permanent environment: + Another delightful way is to use `uv run` with `--with [dependency]` option, which allows you to run commands such as `vllm serve` without creating any permanent environment: -```bash -uv run --with vllm vllm --help -``` + ```bash + uv run --with vllm vllm --help + ``` -You can also use [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html) to create and manage Python environments. You can install `uv` to the conda environment through `pip` if you want to manage it within the environment. + You can also use [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html) to create and manage Python environments. You can install `uv` to the conda environment through `pip` if you want to manage it within the environment. -```bash -conda create -n myenv python=3.12 -y -conda activate myenv -pip install --upgrade uv -uv pip install vllm --torch-backend=auto -``` + ```bash + conda create -n myenv python=3.12 -y + conda activate myenv + pip install --upgrade uv + uv pip install vllm --torch-backend=auto + ``` + +=== "AMD ROCm" + + Use a pre-built docker image from Docker Hub. The public stable image is [rocm/vllm:latest](https://hub.docker.com/r/rocm/vllm). There is also a development image at [rocm/vllm-dev](https://hub.docker.com/r/rocm/vllm-dev). + + The `-v` flag in the `docker run` command below mounts a local directory into the container. Replace `` with the path on your host machine to the directory containing your models. The models will then be accessible inside the container at `/app/models`. 
+ + ???+ console "Commands" + ```bash + docker pull rocm/vllm-dev:nightly # to get the latest image + docker run -it --rm \ + --network=host \ + --group-add=video \ + --ipc=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --device /dev/kfd \ + --device /dev/dri \ + -v :/app/models \ + -e HF_HOME="/app/models" \ + rocm/vllm-dev:nightly + ``` !!! note For more detail and non-CUDA platforms, please refer [here](installation/README.md) for specific instructions on how to install vLLM. -[](){ #quickstart-offline } - ## Offline Batched Inference -With vLLM installed, you can start generating texts for list of input prompts (i.e. offline batch inferencing). See the example script: +With vLLM installed, you can start generating texts for list of input prompts (i.e. offline batch inferencing). See the example script: [examples/offline_inference/basic/basic.py](../../examples/offline_inference/basic/basic.py) The first line of this example imports the classes [LLM][vllm.LLM] and [SamplingParams][vllm.SamplingParams]: @@ -57,7 +79,7 @@ The first line of this example imports the classes [LLM][vllm.LLM] and [Sampling from vllm import LLM, SamplingParams ``` -The next section defines a list of input prompts and sampling parameters for text generation. The [sampling temperature](https://arxiv.org/html/2402.05201v1) is set to `0.8` and the [nucleus sampling probability](https://en.wikipedia.org/wiki/Top-p_sampling) is set to `0.95`. You can find more information about the sampling parameters [here][sampling-params]. +The next section defines a list of input prompts and sampling parameters for text generation. The [sampling temperature](https://arxiv.org/html/2402.05201v1) is set to `0.8` and the [nucleus sampling probability](https://en.wikipedia.org/wiki/Top-p_sampling) is set to `0.95`. You can find more information about the sampling parameters [here](../api/README.md#inference-parameters). !!! important By default, vLLM will use sampling parameters recommended by model creator by applying the `generation_config.json` from the Hugging Face model repository if it exists. In most cases, this will provide you with the best results by default if [SamplingParams][vllm.SamplingParams] is not specified. @@ -135,8 +157,6 @@ for output in outputs: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -[](){ #quickstart-online } - ## OpenAI-Compatible Server vLLM can be deployed as a server that implements the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API. @@ -150,7 +170,7 @@ vllm serve Qwen/Qwen2.5-1.5B-Instruct !!! note By default, the server uses a predefined chat template stored in the tokenizer. - You can learn about overriding it [here][chat-template]. + You can learn about overriding it [here](../serving/openai_compatible_server.md#chat-template). !!! important By default, the server applies `generation_config.json` from the huggingface model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator. 
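The offline path walked through above comes down to only a few lines. For reference, here is a minimal sketch, assuming an illustrative model name and prompt list and using the temperature of 0.8 and top-p of 0.95 quoted in the quickstart text:

```python
from vllm import LLM, SamplingParams

# Illustrative prompts; any list of strings works.
prompts = [
    "Hello, my name is",
    "The capital of France is",
]

# Sampling values mirror those discussed in the quickstart text.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Example model; substitute any checkpoint supported by vLLM.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Each element of `outputs` keeps the original prompt alongside its generated completions, which is why the loop reads `output.outputs[0].text`.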
@@ -194,12 +214,14 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep api_key=openai_api_key, base_url=openai_api_base, ) - completion = client.completions.create(model="Qwen/Qwen2.5-1.5B-Instruct", - prompt="San Francisco is a") + completion = client.completions.create( + model="Qwen/Qwen2.5-1.5B-Instruct", + prompt="San Francisco is a", + ) print("Completion result:", completion) ``` -A more detailed client example can be found here: +A more detailed client example can be found here: [examples/offline_inference/basic/basic.py](../../examples/offline_inference/basic/basic.py) ### OpenAI Chat Completions API with vLLM @@ -239,7 +261,7 @@ Alternatively, you can use the `openai` Python package: messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Tell me a joke."}, - ] + ], ) print("Chat response:", chat_response) ``` @@ -248,7 +270,17 @@ Alternatively, you can use the `openai` Python package: Currently, vLLM supports multiple backends for efficient Attention computation across different platforms and accelerator architectures. It automatically selects the most performant backend compatible with your system and model specifications. -If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`. +If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options: + +- On NVIDIA CUDA: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`. +- On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`. + +For AMD ROCm, you can further control the specific Attention implementation using the following variables: + +- Triton Unified Attention: `VLLM_ROCM_USE_AITER=0 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0` +- AITER Unified Attention: `VLLM_ROCM_USE_AITER=1 VLLM_USE_AITER_UNIFIED_ATTENTION=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=0` +- Triton Prefill-Decode Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 VLLM_ROCM_USE_AITER_MHA=0` +- AITER Multi-head Attention: `VLLM_ROCM_USE_AITER=1 VLLM_V1_USE_PREFILL_DECODE_ATTENTION=0 VLLM_ROCM_USE_AITER_MHA=1` !!! warning - There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see for instructions on how to install it. + There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see [docker/Dockerfile](../../docker/Dockerfile) for instructions on how to install it.
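To make the backend override above concrete: these are plain environment variables, so they can be exported in the shell before the server starts (for example `VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve Qwen/Qwen2.5-1.5B-Instruct`) or set in-process before vLLM is imported for offline use. A minimal sketch, assuming a CUDA system where FlashAttention is available and using an illustrative model name:

```python
import os

# Pin the attention backend before vLLM is imported; vLLM reads this
# environment variable when selecting its attention implementation.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # example model
outputs = llm.generate(["San Francisco is a"])
print(outputs[0].outputs[0].text)
```

As the warning above notes, the pre-built wheels do not bundle FlashInfer, so it must be installed separately before selecting `FLASHINFER`.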
diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index ecd71ee1f3f6..a4da5b933e15 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -22,6 +22,11 @@ class PydanticMagicMock(MagicMock): """`MagicMock` that's able to generate pydantic-core schemas.""" + def __init__(self, *args, **kwargs): + name = kwargs.pop("name", None) + super().__init__(*args, **kwargs) + self.__spec__ = importlib.machinery.ModuleSpec(name, None) + def __get_pydantic_core_schema__(self, source_type, handler): return core_schema.any_schema() @@ -42,7 +47,9 @@ def auto_mock(module, attr, max_mocks=50): raise e except ModuleNotFoundError as e: logger.info("Mocking %s for argparse doc generation", e.name) - sys.modules[e.name] = PydanticMagicMock() + sys.modules[e.name] = PydanticMagicMock(name=e.name) + except Exception as e: + logger.warning("Failed to import %s.%s: %s", module, attr, e) raise ImportError( f"Failed to import {module}.{attr} after mocking {max_mocks} imports" diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index ed8277f628d4..6e4fb039e3a0 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -137,13 +137,20 @@ def replace_link(match): gh_file = (self.main_file.parent / relative_path).resolve() gh_file = gh_file.relative_to(ROOT_DIR) - return f"[{link_text}](gh-file:{gh_file})" + # Make GitHub URL + url = "https://github.com/vllm-project/vllm/" + url += "tree/main" if self.path.is_dir() else "blob/main" + gh_url = f"{url}/{gh_file}" + + return f"[{link_text}]({gh_url})" return re.sub(link_pattern, replace_link, content) def generate(self) -> str: content = f"# {self.title}\n\n" - content += f"Source .\n\n" + url = "https://github.com/vllm-project/vllm/" + url += "tree/main" if self.path.is_dir() else "blob/main" + content += f"Source <{url}/{self.path.relative_to(ROOT_DIR)}>.\n\n" # Use long code fence to avoid issues with # included files containing code fences too diff --git a/docs/mkdocs/hooks/url_schemes.py b/docs/mkdocs/hooks/url_schemes.py index 53b1fbca26b9..f36a64ed7a3b 100644 --- a/docs/mkdocs/hooks/url_schemes.py +++ b/docs/mkdocs/hooks/url_schemes.py @@ -1,123 +1,95 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -This is basically a port of MyST parser’s external URL resolution mechanism -(https://myst-parser.readthedocs.io/en/latest/syntax/cross-referencing.html#customising-external-url-resolution) -to work with MkDocs. +MkDocs hook to enable the following links to render correctly: -It allows Markdown authors to use GitHub shorthand links like: - - - [Text](gh-issue:123) - - - - [File](gh-file:path/to/file.py#L10) - -These are automatically rewritten into fully qualified GitHub URLs pointing to -issues, pull requests, files, directories, or projects in the -`vllm-project/vllm` repository. +- Relative file links outside of the `docs/` directory, e.g.: + - [Text](../some_file.py) + - [Directory](../../some_directory/) +- GitHub URLs for issues, pull requests, and projects, e.g.: + - Adds GitHub icon before links + - Replaces raw links with descriptive text, + e.g. <...pull/123> -> [Pull Request #123](.../pull/123) + - Works for external repos too by including the `owner/repo` in the link title The goal is to simplify cross-referencing common GitHub resources in project docs. 
""" +from pathlib import Path + import regex as re from mkdocs.config.defaults import MkDocsConfig from mkdocs.structure.files import Files from mkdocs.structure.pages import Page +ROOT_DIR = Path(__file__).parent.parent.parent.parent.resolve() +DOC_DIR = ROOT_DIR / "docs" + + +gh_icon = ":octicons-mark-github-16:" + +# Regex pieces +TITLE = r"(?P[^\[\]<>]+?)" +REPO = r"(?P<repo>.+?/.+?)" +TYPE = r"(?P<type>issues|pull|projects)" +NUMBER = r"(?P<number>\d+)" +FRAGMENT = r"(?P<fragment>#[^\s]+)?" +URL = f"https://github.com/{REPO}/{TYPE}/{NUMBER}{FRAGMENT}" +RELATIVE = r"(?!(https?|ftp)://|#)(?P<path>[^\s]+?)" + +# Common titles to use for GitHub links when none is provided in the link. +TITLES = {"issues": "Issue ", "pull": "Pull Request ", "projects": "Project "} + +# Regex to match GitHub issue, PR, and project links with optional titles. +github_link = re.compile(rf"(\[{TITLE}\]\(|<){URL}(\)|>)") +# Regex to match relative file links with optional titles. +relative_link = re.compile(rf"\[{TITLE}\]\({RELATIVE}\)") + def on_page_markdown( markdown: str, *, page: Page, config: MkDocsConfig, files: Files ) -> str: - """ - Custom MkDocs plugin hook to rewrite special GitHub reference links - in Markdown. - - This function scans the given Markdown content for specially formatted - GitHub shorthand links, such as: - - `[Link text](gh-issue:123)` - - `<gh-pr:456>` - - And rewrites them into fully-qualified GitHub URLs with GitHub icons: - - `[:octicons-mark-github-16: Link text](https://github.com/vllm-project/vllm/issues/123)` - - `[:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)` - - Supported shorthand types: - - `gh-issue` - - `gh-pr` - - `gh-project` - - `gh-dir` - - `gh-file` - - Args: - markdown (str): The raw Markdown content of the page. - page (Page): The MkDocs page object being processed. - config (MkDocsConfig): The MkDocs site configuration. - files (Files): The collection of files in the MkDocs build. - - Returns: - str: The updated Markdown content with GitHub shorthand links replaced. - """ - gh_icon = ":octicons-mark-github-16:" - gh_url = "https://github.com" - repo_url = f"{gh_url}/vllm-project/vllm" - org_url = f"{gh_url}/orgs/vllm-project" - - # Mapping of shorthand types to their corresponding GitHub base URLs - urls = { - "issue": f"{repo_url}/issues", - "pr": f"{repo_url}/pull", - "project": f"{org_url}/projects", - "dir": f"{repo_url}/tree/main", - "file": f"{repo_url}/blob/main", - } - - # Default title prefixes for auto links - titles = { - "issue": "Issue #", - "pr": "Pull Request #", - "project": "Project #", - "dir": "", - "file": "", - } - - # Regular expression to match GitHub shorthand links - scheme = r"gh-(?P<type>.+?):(?P<path>.+?)(#(?P<fragment>.+?))?" - inline_link = re.compile(r"\[(?P<title>[^\[]+?)\]\(" + scheme + r"\)") - auto_link = re.compile(f"<{scheme}>") - - def replace_inline_link(match: re.Match) -> str: - """ - Replaces a matched inline-style GitHub shorthand link - with a full Markdown link. - - Example: - [My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123) - """ - url = f"{urls[match.group('type')]}/{match.group('path')}" - if fragment := match.group("fragment"): - url += f"#{fragment}" - - return f"[{gh_icon} {match.group('title')}]({url})" - - def replace_auto_link(match: re.Match) -> str: - """ - Replaces a matched autolink-style GitHub shorthand - with a full Markdown link. 
- - Example: - <gh-pr:456> → [:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456) - """ - type = match.group("type") + def replace_relative_link(match: re.Match) -> str: + """Replace relative file links with URLs if they point outside the docs dir.""" + title = match.group("title") path = match.group("path") - title = f"{titles[type]}{path}" - url = f"{urls[type]}/{path}" - if fragment := match.group("fragment"): - url += f"#{fragment}" + path = (Path(page.file.abs_src_path).parent / path).resolve() + + # Check if the path exists and is outside the docs dir + if not path.exists() or path.is_relative_to(DOC_DIR): + return match.group(0) + + # Files and directories have different URL schemes on GitHub + slug = "tree/main" if path.is_dir() else "blob/main" + path = path.relative_to(ROOT_DIR) + url = f"https://github.com/vllm-project/vllm/{slug}/{path}" return f"[{gh_icon} {title}]({url})" - # Replace both inline and autolinks - markdown = inline_link.sub(replace_inline_link, markdown) - markdown = auto_link.sub(replace_auto_link, markdown) + def replace_github_link(match: re.Match) -> str: + """Replace GitHub issue, PR, and project links with enhanced Markdown links.""" + repo = match.group("repo") + type = match.group("type") + number = match.group("number") + # Title and fragment could be None + title = match.group("title") or "" + fragment = match.group("fragment") or "" + + # Use default titles for raw links + if not title: + title = TITLES[type] + if "vllm-project" not in repo: + title += repo + title += f"#{number}" + + url = f"https://github.com/{repo}/{type}/{number}{fragment}" + return f"[{gh_icon} {title}]({url})" + + markdown = relative_link.sub(replace_relative_link, markdown) + markdown = github_link.sub(replace_github_link, markdown) + + if "interface" in str(page.file.abs_src_path): + print(markdown) return markdown diff --git a/docs/models/extensions/fastsafetensor.md b/docs/models/extensions/fastsafetensor.md index 2a5a18102dc2..0f30d4e2f69d 100644 --- a/docs/models/extensions/fastsafetensor.md +++ b/docs/models/extensions/fastsafetensor.md @@ -3,4 +3,4 @@ Loading Model weights with fastsafetensors Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details. -To enable this feature, use the ``--load-format fastsafetensors`` command-line argument +To enable this feature, use the `--load-format fastsafetensors` command-line argument diff --git a/docs/models/extensions/runai_model_streamer.md b/docs/models/extensions/runai_model_streamer.md index 8a97a49825a4..c2cf107263a0 100644 --- a/docs/models/extensions/runai_model_streamer.md +++ b/docs/models/extensions/runai_model_streamer.md @@ -82,7 +82,7 @@ vllm serve /path/to/sharded/model \ --model-loader-extra-config '{"pattern":"custom-model-rank-{rank}-part-{part}.safetensors"}' ``` -To create sharded model files, you can use the script provided in <gh-file:examples/offline_inference/save_sharded_state.py>. This script demonstrates how to save a model in the sharded format that is compatible with the Run:ai Model Streamer sharded loader. +To create sharded model files, you can use the script provided in [examples/offline_inference/save_sharded_state.py](../../../examples/offline_inference/save_sharded_state.py). 
This script demonstrates how to save a model in the sharded format that is compatible with the Run:ai Model Streamer sharded loader. The sharded loader supports all the same tunable parameters as the regular Run:ai Model Streamer, including `concurrency` and `memory_limit`. These can be configured in the same way: diff --git a/docs/models/extensions/tensorizer.md b/docs/models/extensions/tensorizer.md index f70ab0c6f4e5..3df80d5af6c4 100644 --- a/docs/models/extensions/tensorizer.md +++ b/docs/models/extensions/tensorizer.md @@ -60,7 +60,7 @@ from vllm import LLM llm = LLM( "s3://my-bucket/vllm/facebook/opt-125m/v1", load_format="tensorizer", - enable_lora=True + enable_lora=True, ) ``` @@ -97,6 +97,6 @@ llm = LLM( "s3://my-bucket/vllm/facebook/opt-125m/v1", load_format="tensorizer", enable_lora=True, - model_loader_extra_config={"deserialization_kwargs": {"num_readers": 2}} + model_loader_extra_config={"deserialization_kwargs": {"num_readers": 2}}, ) ``` diff --git a/docs/models/generative_models.md b/docs/models/generative_models.md index 05f8d16cc4ca..be2f25bf0661 100644 --- a/docs/models/generative_models.md +++ b/docs/models/generative_models.md @@ -59,7 +59,7 @@ for output in outputs: By default, vLLM will use sampling parameters recommended by model creator by applying the `generation_config.json` from the huggingface model repository if it exists. In most cases, this will provide you with the best results by default if [SamplingParams][vllm.SamplingParams] is not specified. However, if vLLM's default sampling parameters are preferred, please pass `generation_config="vllm"` when creating the [LLM][vllm.LLM] instance. -A code example can be found here: <gh-file:examples/offline_inference/basic/basic.py> +A code example can be found here: [examples/offline_inference/basic/basic.py](../../examples/offline_inference/basic/basic.py) ### `LLM.beam_search` @@ -98,15 +98,15 @@ and automatically applies the model's [chat template](https://huggingface.co/doc conversation = [ { "role": "system", - "content": "You are a helpful assistant" + "content": "You are a helpful assistant", }, { "role": "user", - "content": "Hello" + "content": "Hello", }, { "role": "assistant", - "content": "Hello! How can I assist you today?" + "content": "Hello! How can I assist you today?", }, { "role": "user", @@ -121,7 +121,7 @@ and automatically applies the model's [chat template](https://huggingface.co/doc print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -A code example can be found here: <gh-file:examples/offline_inference/basic/chat.py> +A code example can be found here: [examples/offline_inference/basic/chat.py](../../examples/offline_inference/basic/chat.py) If the model doesn't have a chat template or you want to specify another one, you can explicitly pass a chat template: @@ -140,5 +140,5 @@ outputs = llm.chat(conversation, chat_template=custom_template) Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs: -- [Completions API][completions-api] is similar to `LLM.generate` but only accepts text. -- [Chat API][chat-api] is similar to `LLM.chat`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for models with a chat template. +- [Completions API](../serving/openai_compatible_server.md#completions-api) is similar to `LLM.generate` but only accepts text. 
+- [Chat API](../serving/openai_compatible_server.md#chat-api) is similar to `LLM.chat`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for models with a chat template. diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 50982d3d0d0f..40651be1d449 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -9,7 +9,7 @@ before returning them. !!! note We currently support pooling models primarily as a matter of convenience. This is not guaranteed to have any performance improvement over using HF Transformers / Sentence Transformers directly. - We are now planning to optimize pooling models in vLLM. Please comment on <gh-issue:21796> if you have any suggestions! + We are now planning to optimize pooling models in vLLM. Please comment on <https://github.com/vllm-project/vllm/issues/21796> if you have any suggestions! ## Configuration @@ -98,7 +98,7 @@ embeds = output.outputs.embedding print(f"Embeddings: {embeds!r} (size={len(embeds)})") ``` -A code example can be found here: <gh-file:examples/offline_inference/basic/embed.py> +A code example can be found here: [examples/offline_inference/basic/embed.py](../../examples/offline_inference/basic/embed.py) ### `LLM.classify` @@ -115,7 +115,7 @@ probs = output.outputs.probs print(f"Class Probabilities: {probs!r} (size={len(probs)})") ``` -A code example can be found here: <gh-file:examples/offline_inference/basic/classify.py> +A code example can be found here: [examples/offline_inference/basic/classify.py](../../examples/offline_inference/basic/classify.py) ### `LLM.score` @@ -130,14 +130,16 @@ It is designed for embedding models and cross-encoder models. Embedding models u from vllm import LLM llm = LLM(model="BAAI/bge-reranker-v2-m3", runner="pooling") -(output,) = llm.score("What is the capital of France?", - "The capital of Brazil is Brasilia.") +(output,) = llm.score( + "What is the capital of France?", + "The capital of Brazil is Brasilia.", +) score = output.outputs.score print(f"Score: {score}") ``` -A code example can be found here: <gh-file:examples/offline_inference/basic/score.py> +A code example can be found here: [examples/offline_inference/basic/score.py](../../examples/offline_inference/basic/score.py) ### `LLM.reward` @@ -154,7 +156,7 @@ data = output.outputs.data print(f"Data: {data!r}") ``` -A code example can be found here: <gh-file:examples/offline_inference/basic/reward.py> +A code example can be found here: [examples/offline_inference/basic/reward.py](../../examples/offline_inference/basic/reward.py) ### `LLM.encode` @@ -183,10 +185,10 @@ print(f"Data: {data!r}") Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs: -- [Pooling API][pooling-api] is similar to `LLM.encode`, being applicable to all types of pooling models. -- [Embeddings API][embeddings-api] is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models. -- [Classification API][classification-api] is similar to `LLM.classify` and is applicable to sequence classification models. -- [Score API][score-api] is similar to `LLM.score` for cross-encoder models. +- [Pooling API](../serving/openai_compatible_server.md#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models. 
+- [Embeddings API](../serving/openai_compatible_server.md#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models. +- [Classification API](../serving/openai_compatible_server.md#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models. +- [Score API](../serving/openai_compatible_server.md#score-api) is similar to `LLM.score` for cross-encoder models. ## Matryoshka Embeddings @@ -209,7 +211,7 @@ For models that support Matryoshka Embeddings but not recognized by vLLM, please Here is an example to serve a model with Matryoshka Embeddings enabled. -```text +```bash vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf-overrides '{"matryoshka_dimensions":[256]}' ``` @@ -220,27 +222,31 @@ You can change the output dimensions of embedding models that support Matryoshka ```python from vllm import LLM, PoolingParams -llm = LLM(model="jinaai/jina-embeddings-v3", - runner="pooling", - trust_remote_code=True) -outputs = llm.embed(["Follow the white rabbit."], - pooling_params=PoolingParams(dimensions=32)) +llm = LLM( + model="jinaai/jina-embeddings-v3", + runner="pooling", + trust_remote_code=True, +) +outputs = llm.embed( + ["Follow the white rabbit."], + pooling_params=PoolingParams(dimensions=32), +) print(outputs[0].outputs) ``` -A code example can be found here: <gh-file:examples/offline_inference/pooling/embed_matryoshka_fy.py> +A code example can be found here: [examples/offline_inference/pooling/embed_matryoshka_fy.py](../../examples/offline_inference/pooling/embed_matryoshka_fy.py) ### Online Inference Use the following command to start vllm server. -```text +```bash vllm serve jinaai/jina-embeddings-v3 --trust-remote-code ``` You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter. -```text +```bash curl http://127.0.0.1:8000/v1/embeddings \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ @@ -258,4 +264,4 @@ Expected output: {"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}} ``` -An OpenAI client example can be found here: <gh-file:examples/online_serving/pooling/openai_embedding_matryoshka_fy.py> +An OpenAI client example can be found here: [examples/online_serving/pooling/openai_embedding_matryoshka_fy.py](../../examples/online_serving/pooling/openai_embedding_matryoshka_fy.py) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 5ac8f2121f97..d726b5350f80 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -9,11 +9,9 @@ Alongside each architecture, we include some popular models that use it. ### vLLM -If vLLM natively supports a model, its implementation can be found in <gh-file:vllm/model_executor/models>. 
+If vLLM natively supports a model, its implementation can be found in [vllm/model_executor/models](../../vllm/model_executor/models). -These models are what we list in [supported-text-models][supported-text-models] and [supported-mm-models][supported-mm-models]. - -[](){ #transformers-backend } +These models are what we list in [supported text models](#list-of-text-only-language-models) and [supported multimodal models](#list-of-multimodal-language-models). ### Transformers @@ -60,7 +58,7 @@ For a model to be compatible with the Transformers backend for vLLM it must: - be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)): - The model directory must have the correct structure (e.g. `config.json` is present). - `config.json` must contain `auto_map.AutoModel`. -- be a Transformers backend for vLLM compatible model (see [writing-custom-models][writing-custom-models]): +- be a Transformers backend for vLLM compatible model (see [Writing custom models](#writing-custom-models)): - Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`). If the compatible model is: @@ -70,8 +68,6 @@ If the compatible model is: This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! -[](){ #writing-custom-models } - #### Writing custom models This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)). @@ -116,7 +112,7 @@ Here is what happens in the background when this model is loaded: 1. The config is loaded. 2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. -3. `MyModel` is loaded into one of the Transformers backend classes in <gh-file:vllm/model_executor/models/transformers.py> which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. +3. `MyModel` is loaded into one of the Transformers backend classes in [vllm/model_executor/models/transformers](../../vllm/model_executor/models/transformers) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. That's it! @@ -164,7 +160,7 @@ To determine whether a given model is natively supported, you can check the `con If the `"architectures"` field contains a model architecture listed below, then it should be natively supported. Models do not _need_ to be natively supported to be used in vLLM. -The [Transformers backend][transformers-backend] enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!). +The [Transformers backend](#transformers) enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!). !!! 
tip The easiest way to check if your model is really supported at runtime is to run the program below: @@ -278,8 +274,8 @@ https_proxy=http://your.proxy.server:port vllm serve <model_name> ```python import os -os.environ['http_proxy'] = 'http://your.proxy.server:port' -os.environ['https_proxy'] = 'http://your.proxy.server:port' +os.environ["http_proxy"] = "http://your.proxy.server:port" +os.environ["https_proxy"] = "http://your.proxy.server:port" ``` ### ModelScope @@ -306,8 +302,6 @@ output = llm.encode("Hello, my name is") print(output) ``` -[](){ #feature-status-legend } - ## Feature Status Legend - ✅︎ indicates that the feature is supported for the model. @@ -316,8 +310,6 @@ print(output) - ⚠️ indicates that the feature is available but may have known issues or limitations. -[](){ #supported-text-models } - ## List of Text-only Language Models ### Generative Models @@ -335,107 +327,108 @@ th { } </style> -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | -| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ | -| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | -| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | -| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | ✅︎ | -| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | -| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Exaone4ForCausalLM` | EXAONE-4 | `LGAI-EXAONE/EXAONE-4.0-32B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | -| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | -| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | -| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ | -| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | -| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | -| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ | ✅︎ | -| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | -| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | -| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ | ✅︎ | -| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | ✅︎ | -| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | -| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Lfm2MoeForCausalLM` | LFM2MoE | `LiquidAI/LFM2-8B-A1B-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ | -| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | -| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | -| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ | -| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | -| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | -| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | ✅︎ | -| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3NextForCausalLM` | Qwen3NextMoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | -| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | -| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ | -| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ | -| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ | -| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ |✅︎ | ✅︎ | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|-------------------|----------------------|---------------------------| +| `ApertusForCausalLM` | Apertus | `swiss-ai/Apertus-8B-2509`, `swiss-ai/Apertus-70B-Instruct-2509`, etc. | ✅︎ | ✅︎ | +| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | +| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | +| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | +| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | +| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | +| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | +| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | +| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. 
| | ✅︎ | +| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | ✅︎ | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | +| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | +| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | +| `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | +| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | +| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | +| `Exaone4ForCausalLM` | EXAONE-4 | `LGAI-EXAONE/EXAONE-4.0-32B`, etc. | ✅︎ | ✅︎ | +| `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ | +| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | +| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | +| `FlexOlmoForCausalLM` | FlexOlmo | `allenai/FlexOlmo-7x7B-1T`, `allenai/FlexOlmo-7x7B-1T-RT`, etc. | | ✅︎ | +| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | +| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | +| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | +| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | +| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | +| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | +| `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | +| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | +| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | +| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | +| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | +| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ | +| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | +| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | +| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. 
| ✅︎ | ✅︎ | +| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | +| `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ | +| `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | +| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | +| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | +| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | +| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | +| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | +| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | +| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | +| `Lfm2MoeForCausalLM` | LFM2MoE | `LiquidAI/LFM2-8B-A1B-preview`, etc. | ✅︎ | ✅︎ | +| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | +| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | +| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | +| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | +| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | +| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | +| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | +| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | +| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | +| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | +| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | +| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | +| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ | +| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | +| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. 
| ✅︎ | ✅︎ | +| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | +| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | +| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | ✅︎ | +| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | +| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | +| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | +| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | +| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | +| `Qwen3NextForCausalLM` | Qwen3NextMoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | +| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | +| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | +| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | +| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | +| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | +| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | +| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | +| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | +| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | +| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ | Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `SmolLM3ForCausalLM` | SmolLM3 | `HuggingFaceTB/SmolLM3-3B` | ✅︎ | ✅︎ | ✅︎ | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|-------------------|----------------------|---------------------------| +| `SmolLM3ForCausalLM` | SmolLM3 | `HuggingFaceTB/SmolLM3-3B` | ✅︎ | ✅︎ | !!! 
note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. @@ -452,21 +445,21 @@ See [this page](./pooling_models.md) for more information on how to use pooling These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ | -| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ | -| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ | -| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ | -| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | ✅︎ | -| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | ✅︎ | -| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | ✅︎ | -| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|-------------------|----------------------|---------------------------| +| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | +| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | +| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | +| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | +| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | +| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | +| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | +| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | +| `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. 
| ✅︎ | ✅︎ | +| `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | +| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | <sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion)) \* Feature support is the same as that of the original model. @@ -493,11 +486,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding These models primarily support the [`LLM.classify`](./pooling_models.md#llmclassify) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | -| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|-------------------|----------------------|---------------------------| +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | +| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | +| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | <sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion)) \* Feature support is the same as that of the original model. @@ -510,16 +503,16 @@ If your model is not in the above list, we will try to automatically convert the Cross-encoder and reranker models are a subset of classification models that accept two prompts as input. These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ | -| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | ✅︎ | -| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ | -| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | ✅︎ | -| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. 
| Generative models | N/A | \* | \* | \* | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|-------------------|----------------------|---------------------------| +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | +| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | +| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | | +| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | +| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | +| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | <sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion)) \* Feature support is the same as that of the original model. @@ -542,7 +535,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A ``` !!! note - Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/pooling/qwen3_reranker.py>. + Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/offline_inference/pooling/qwen3_reranker.py](../../examples/offline_inference/pooling/qwen3_reranker.py). ```bash vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' @@ -552,13 +545,13 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|-------------------|----------------------|---------------------------| +| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | +| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. 
| ✅︎ | ✅︎ | +| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | +| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | +| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | <sup>C</sup> Automatically converted into a reward model via `--convert reward`. ([details](./pooling_models.md#model-conversion)) \* Feature support is the same as that of the original model. @@ -574,15 +567,13 @@ If your model is not in the above list, we will try to automatically convert the These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API. -| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------| -| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ | -| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | ✅︎ | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|-------------------|-----------------------------|-----------------------------------------| +| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | +| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | | !!! note - Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>. - -[](){ #supported-mm-models } + Named Entity Recognition (NER) usage, please refer to [examples/offline_inference/pooling/ner.py](../../examples/offline_inference/pooling/ner.py), [examples/online_serving/pooling/ner_client.py](../../examples/online_serving/pooling/ner_client.py). ## List of Multimodal Language Models @@ -603,29 +594,6 @@ On the other hand, modalities separated by `/` are mutually exclusive. See [this page](../features/multimodal_inputs.md) on how to pass multi-modal inputs to the model. -!!! important - **To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference) - or `--limit-mm-per-prompt` (online serving). For example, to enable passing up to 4 images per text prompt: - - Offline inference: - - ```python - from vllm import LLM - - llm = LLM( - model="Qwen/Qwen2-VL-7B-Instruct", - limit_mm_per_prompt={"image": 4}, - ) - ``` - - Online serving: - - ```bash - vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}' - ``` - - **This is no longer required if you are using vLLM V1.** - !!! tip For hybrid-only models such as Llama-4, Step3 and Mistral-3, a text-only mode can be enabled by setting all supported multimodal modalities to 0 (e.g, `--limit-mm-per-prompt '{"image":0}`) so that their multimodal modules will not be loaded to free up more GPU memory for KV cache. @@ -662,69 +630,73 @@ See [this page](generative_models.md) for more information on how to use generat These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API. 
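As a concrete illustration of the text-only tip above, a minimal sketch might look like the following; the model name is just one of the hybrid models mentioned in that tip, and the flag value reuses the same `'{"image": 0}'` pattern quoted there, so adapt both to your own deployment.

```bash
# Sketch only: serve a hybrid multimodal model in text-only mode by setting
# every supported multimodal modality to 0 (per the tip above), so the
# multimodal modules are not loaded and more GPU memory is left for KV cache.
# Model name and limit values are illustrative, not a recommendation.
vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \
    --limit-mm-per-prompt '{"image": 0}'
```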
-| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| -| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ | -| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ | -| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | -| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | -| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | -| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | -| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | -| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | -| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | -| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | -| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | -| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | -| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | -| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ | -| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | ✅︎ | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | -| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | -| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ | -| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | -| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | -| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | -| `MiDashengLMModel` | MiDashengLM | T + A<sup>+</sup> | `mispeech/midashenglm-7b` | | ✅︎ | ✅︎ | -| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | -| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | -| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | -| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | -| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ | -| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | -| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | -| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | ✅︎ | -| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | -| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ | -| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | -| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | -| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ | -| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | -| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|--------|-------------------|----------------------|---------------------------| +| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | +| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | +| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ | +| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | +| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | +| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | +| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | +| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | +| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. 
| | ✅︎ | +| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | +| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | +| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | +| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | +| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | +| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | +| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | +| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | +| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | +| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | +| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | +| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | +| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | +| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ | +| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | +| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | +| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | +| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | +| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | +| `MiDashengLMModel` | MiDashengLM | T + A<sup>+</sup> | `mispeech/midashenglm-7b` | | ✅︎ | +| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. 
| ✅︎ | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | +| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | +| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | +| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | +| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | +| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | +| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | +| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | +| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | +| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | +| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | +| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | +| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | +| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | +| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | +| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-3B`, `Qwen/Qwen2.5-Omni-7B` | ✅︎ | ✅︎ | +| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | +| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. 
| ✅︎ | ✅︎ | +| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | +| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | +| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | +| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | +| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | +| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | +| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! -| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------| -| `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ | ✅︎ | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------| +| `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ | <sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.     • For example, to use DeepSeek-VL2 series models: @@ -797,24 +769,23 @@ Some models are supported only via the [Transformers backend](#transformers). Th !!! note The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. - For more details, please see: <gh-pr:4087#issuecomment-2250397630> + For more details, please see: <https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630> !!! warning Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. !!! note - For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) - is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1. + For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported. #### Transcription Speech2Text models trained specifically for Automatic Speech Recognition. 
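To try the transcription path end to end, a minimal sketch follows; it assumes the OpenAI-compatible transcription endpoint (`/v1/audio/transcriptions`), the default port, and a local `audio.wav` file, so verify these against your vLLM version before relying on it.

```bash
# Sketch only: serve a Whisper checkpoint from the table below, then send a
# local audio file to the OpenAI-compatible transcription endpoint.
# The endpoint path, port, and file name are assumptions; check them against
# your vLLM version's OpenAI-compatible server documentation.
vllm serve openai/whisper-large-v3-turbo

# In a second terminal, once the server is up:
curl http://localhost:8000/v1/audio/transcriptions \
    -F file=@audio.wav \
    -F model=openai/whisper-large-v3-turbo
```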
-| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|-------------------|----------------------|---------------------------|---------------------| -| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | ✅︎ | -| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|-------------------|----------------------|---------------------------| +| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | +| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ | +| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ### Pooling Models @@ -829,12 +800,13 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A The following table lists those that are tested in vLLM. -| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| -| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ | -| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ | -| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ | -| `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* | \* | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|--------|-------------------|----------------------|---------------------------| +| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | +| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | +| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | +| `SiglipModel` | SigLIP | T / I | `google/siglip-base-patch16-224` | | | +| `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* | <sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion)) \* Feature support is the same as that of the original model. @@ -846,9 +818,9 @@ The following table lists those that are tested in vLLM. Cross-encoder and reranker models are a subset of classification models that accept two prompts as input. These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API. 
-| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | -|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------| -| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | | | ✅︎ | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | +|--------------|--------|--------|-------------------|----------------------|---------------------------| +| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | ✅︎ | ✅︎ | <sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion)) \* Feature support is the same as that of the original model. @@ -878,5 +850,5 @@ We have the following levels of testing for models: 1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. -3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. +3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](../../tests) and [examples](../../examples) for the models that have passed this test. 4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. diff --git a/docs/serving/context_parallel_deployment.md b/docs/serving/context_parallel_deployment.md new file mode 100644 index 000000000000..dacdf312ee55 --- /dev/null +++ b/docs/serving/context_parallel_deployment.md @@ -0,0 +1,47 @@ +# Context Parallel Deployment + +Context parallel mainly solves the problem of serving long context requests. As prefill and decode present quite different characteristics and have quite different SLOs (service level objectives), we need to implement context parallel separately for them. The major considerations are: + +- For long context prefill, we need to control the TTFT (time to first token) by amortizing the computation time of the prefill across query tokens. +- For long context decode, we need more space for the KV cache to increase the batch size (and hence the throughput). + +## Prefill Context Parallel + +During prefill, for a long request with `T` new tokens, we need to compute query/key/value tensors for these new tokens. Say we have `N` GPUs; we can split the request into `N` chunks, and each GPU computes one chunk of the query/key/value tensors. + +Depending on the use case, there are two possible strategies: + +1. 
Partial query, full key/value: If the request token length is moderately long (we can afford to hold the full key/value tensors), and the goal is to accelerate the prefill (and amortize its computation time across query tokens), then we can gather the key/value tensors from all GPUs and let each GPU compute the attention output corresponding to the query tokens of its chunk. +2. Partial query, partial key/value: If the request token length is so long that we can no longer afford to hold the full key/value tensors, then each GPU can only compute one chunk of the query/key/value tensors, and we use techniques like [ring-attention](http://arxiv.org/abs/2310.01889) to send/recv key/value tensors chunk by chunk. + +Both approaches are under active development. + +## Decode Context Parallel + +Due to the auto-regressive nature of decoding, every decoding step needs to compute a small number of query tokens against a large number of key/value tokens stored in the paged KV cache. The core question of decode context parallel is how to shard the KV cache across GPUs. + +For a model with `H` kv-heads, a request with `T` tokens in the context needs to store `H * T` key/value tensors in the KV cache. + +1. If one GPU can hold them all, and the performance is good enough, then no parallelization is needed. +2. If one GPU cannot hold them all, or we want to hold more requests in the KV cache, we can first shard the KV cache along the `H` dimension; this is plain tensor parallel sharding, and it's as simple as adding `-tp <num_gpus>` to the command line. +3. Since `H` is limited (determined by the model architecture), when we continue to increase the tensor parallel size, the KV cache on each GPU will be duplicated `tp_size / H` times. Of course, duplication is not good for efficiency, so we need to add decode context parallel to further shard the KV cache along the `T` dimension. This is as simple as adding `-dcp <size>` to the command line. Note that `size` does not increase the number of GPUs we need to launch; it just reduces the KV cache duplication. The dcp size should lie in the range `[1, tp_size/H]`. With a larger dcp size, the KV cache duplication is reduced, but the communication overhead increases. + +Theoretically, it is possible to extend the dcp size beyond `tp_size / H` to further shard the KV cache and accelerate the decoding phase. However, since the number of query tokens is limited in decoding, it's unclear what we should do with the remaining `dcp_size - tp_size / H` GPUs for the non-attention layers. For the sake of simplicity, the dcp size is upper-bounded by `tp_size / H`. If you want to further accelerate the decoding phase, consider increasing `tp_size` first, and then increasing the dcp size. + +Note that the KV cache can grow during decoding, so the sharding strategy needs to be implemented carefully. We use an interleaving strategy to shard the KV cache along the `T` dimension, so that the KV cache for future tokens is naturally sharded along the `T` dimension as well. This approach was proposed by [Chao Hong from Moonshot](https://github.com/youzhedian), and is also explained in detail in [this paper](http://arxiv.org/abs/2507.07120). + +Case study: + +For DeepSeek-R1, we have 1 kv-head when MLA is enabled. The typical single-node deployment with `-tp 8` causes 8x KV cache duplication. We can consider adding `-dcp 8` to reduce the KV cache duplication. + +For Kimi-K2, the architecture is similar to DeepSeek-R1, but with more parameters. 
When we deploy it with `-tp 16`, the KV cache duplication is 16x. We can add `-dcp 16` to completely remove the KV cache duplication, at the cost of more communication overhead. We can also add `-dcp 8` to reduce the KV cache duplication to 2x. Although it still duplicates the KV cache twice, the communication overhead is smaller since the DCP communication only happens inside one node. + +For Qwen3-235B-A22B, we have 4 kv-heads. When we deploy it with `-tp 8`, the KV cache duplication is 2x. Then we can add `-dcp 2` to remove the KV cache duplication. + +In short, for decode context parallel, try to increase `-tp` size until you get satisfactory performance, and then add `-dcp` to reduce the KV cache duplication. + +Decode context parallel is supported in vLLM, for both MLA and GQA models. Some attention backends also support the combination of decode context parallel and MTP (multi-token prediction) to further accelerate the decoding phase. + +## Technical Discussions + +The main discussions happen in the `#sig-context-parallel` channel of [vLLM Slack](https://slack.vllm.ai/). diff --git a/docs/serving/data_parallel_deployment.md b/docs/serving/data_parallel_deployment.md index 9ff9f59c54e5..eff9c5d5e4ef 100644 --- a/docs/serving/data_parallel_deployment.md +++ b/docs/serving/data_parallel_deployment.md @@ -16,7 +16,7 @@ For MoE models, when any requests are in progress in any rank, we must ensure th In all cases, it is beneficial to load-balance requests between DP ranks. For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently. -This document focuses on online deployments (with the API server). DP + EP is also supported for offline usage (via the LLM class), for an example see <gh-file:examples/offline_inference/data_parallel.py>. +This document focuses on online deployments (with the API server). DP + EP is also supported for offline usage (via the LLM class), for an example see [examples/offline_inference/data_parallel.py](../../examples/offline_inference/data_parallel.py). There are two distinct modes supported for online deployments - self-contained with internal load balancing, or externally per-rank process deployment and load balancing. @@ -69,6 +69,7 @@ There are several notable differences when using Ray: - A single launch command (on any node) is needed to start all local and remote DP ranks, therefore it is more convenient compared to launching on each node - There is no need to specify `--data-parallel-address`, and the node where the command is run is used as `--data-parallel-address` - There is no need to specify `--data-parallel-rpc-port` +- When a single DP group requires multiple nodes, *e.g.* in case a single model replica needs to run on at least two nodes, make sure to set `VLLM_RAY_DP_PACK_STRATEGY="span"` in which case `--data-parallel-size-local` is ignored and will be automatically determined - Remote DP ranks will be allocated based on node resources of the Ray cluster Currently, the internal DP load balancing is done within the API server process(es) and is based on the running and waiting queues in each of the engines. This could be made more sophisticated in future by incorporating KV cache aware logic. 
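Tying the decode context parallel case study above back to a command line, a minimal sketch of the single-node DeepSeek-R1 deployment might look like the following; the short flags `-tp`/`-dcp` are the ones quoted on that page, and the values simply mirror the 8-GPU example, so treat this as an illustration rather than a tuned configuration.

```bash
# Sketch of the DeepSeek-R1 case study above: 8-way tensor parallelism plus
# 8-way decode context parallelism, so the single MLA kv-head is no longer
# duplicated across all 8 GPUs. Values are illustrative, not tuned.
vllm serve deepseek-ai/DeepSeek-R1 \
    -tp 8 \
    -dcp 8
```

The same pattern extends to the multi-node Kimi-K2 example by raising `-tp` and choosing `-dcp` within the `[1, tp_size/H]` range discussed above.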
diff --git a/docs/serving/distributed_troubleshooting.md b/docs/serving/distributed_troubleshooting.md index bd45f010ed2a..b5354a7e55d5 100644 --- a/docs/serving/distributed_troubleshooting.md +++ b/docs/serving/distributed_troubleshooting.md @@ -4,11 +4,11 @@ For general troubleshooting, see [Troubleshooting](../usage/troubleshooting.md). ## Verify inter-node GPU communication -After you start the Ray cluster, verify GPU-to-GPU communication across nodes. Proper configuration can be non-trivial. For more information, see [troubleshooting script][troubleshooting-incorrect-hardware-driver]. If you need additional environment variables for communication configuration, append them to <gh-file:examples/online_serving/run_cluster.sh>, for example `-e NCCL_SOCKET_IFNAME=eth0`. Setting environment variables during cluster creation is recommended because the variables propagate to all nodes. In contrast, setting environment variables in the shell affects only the local node. For more information, see <gh-issue:6803>. +After you start the Ray cluster, verify GPU-to-GPU communication across nodes. Proper configuration can be non-trivial. For more information, see [troubleshooting script](../usage/troubleshooting.md#incorrect-hardwaredriver). If you need additional environment variables for communication configuration, append them to [examples/online_serving/run_cluster.sh](../../examples/online_serving/run_cluster.sh), for example `-e NCCL_SOCKET_IFNAME=eth0`. Setting environment variables during cluster creation is recommended because the variables propagate to all nodes. In contrast, setting environment variables in the shell affects only the local node. For more information, see <https://github.com/vllm-project/vllm/issues/6803>. ## No available node types can fulfill resource request -The error message `Error: No available node types can fulfill resource request` can appear even when the cluster has enough GPUs. The issue often occurs when nodes have multiple IP addresses and vLLM can't select the correct one. Ensure that vLLM and Ray use the same IP address by setting `VLLM_HOST_IP` in <gh-file:examples/online_serving/run_cluster.sh> (with a different value on each node). Use `ray status` and `ray list nodes` to verify the chosen IP address. For more information, see <gh-issue:7815>. +The error message `Error: No available node types can fulfill resource request` can appear even when the cluster has enough GPUs. The issue often occurs when nodes have multiple IP addresses and vLLM can't select the correct one. Ensure that vLLM and Ray use the same IP address by setting `VLLM_HOST_IP` in [examples/online_serving/run_cluster.sh](../../examples/online_serving/run_cluster.sh) (with a different value on each node). Use `ray status` and `ray list nodes` to verify the chosen IP address. For more information, see <https://github.com/vllm-project/vllm/issues/7815>. ## Ray observability diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index e44a914c726d..ec07896592ba 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -8,19 +8,22 @@ EP is typically coupled with Data Parallelism (DP). While DP can be used indepen Before using EP, you need to install the necessary dependencies. We are actively working on making this easier in the future: -1. **Install DeepEP and pplx-kernels**: Set up host environment following vLLM's guide for EP kernels [here](gh-file:tools/ep_kernels). +1. 
**Install DeepEP and pplx-kernels**: Set up host environment following vLLM's guide for EP kernels [here](../../tools/ep_kernels). 2. **Install DeepGEMM library**: Follow the [official instructions](https://github.com/deepseek-ai/DeepGEMM#installation). -3. **For disaggregated serving**: Install `gdrcopy` by running the [`install_gdrcopy.sh`](gh-file:tools/install_gdrcopy.sh) script (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). +3. **For disaggregated serving**: Install `gdrcopy` by running the [`install_gdrcopy.sh`](../../tools/install_gdrcopy.sh) script (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). ### Backend Selection Guide -vLLM provides three communication backends for EP: +vLLM provides multiple communication backends for EP. Use `--all2all-backend` to select one: | Backend | Use Case | Features | Best For | |---------|----------|----------|----------| -| `pplx` | Single node | Chunked prefill support | Development, best for intra-node deployments | -| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout | High-throughput scenarios, prefill-dominated workloads | -| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout | Low-latency scenarios, decode-dominated workloads | +| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration | +| `pplx` | Single node | Chunked prefill support, efficient intra-node communication | Single-node deployments, development | +| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios | +| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios | +| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes | +| `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production | ## Single Node Deployment @@ -47,11 +50,11 @@ The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parall ```bash # Single node EP deployment with pplx backend -VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 \ - vllm serve deepseek-ai/DeepSeek-V3-0324 \ - --tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU +vllm serve deepseek-ai/DeepSeek-V3-0324 \ + --tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU --data-parallel-size 8 \ # Data parallelism across 8 processes - --enable-expert-parallel # Enable expert parallelism + --enable-expert-parallel \ # Enable expert parallelism + --all2all-backend pplx # Use pplx communication backend ``` ## Multi-Node Deployment @@ -70,8 +73,8 @@ The following example deploys `DeepSeek-V3-0324` across 2 nodes using `deepep_lo ```bash # Node 1 (Primary - handles incoming requests) -VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \ - vllm serve deepseek-ai/DeepSeek-V3-0324 \ +vllm serve deepseek-ai/DeepSeek-V3-0324 \ + --all2all-backend deepep_low_latency \ --tensor-parallel-size 1 \ # TP size per node --enable-expert-parallel \ # Enable EP --data-parallel-size 16 \ 
# Total DP size across all nodes @@ -81,8 +84,8 @@ VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \ --api-server-count=8 # Number of API servers for load handling (scaling this out to the total number of ranks is recommended) # Node 2 (Secondary - headless mode, no API server) -VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \ - vllm serve deepseek-ai/DeepSeek-V3-0324 \ +vllm serve deepseek-ai/DeepSeek-V3-0324 \ + --all2all-backend deepep_low_latency \ --tensor-parallel-size 1 \ # TP size per node --enable-expert-parallel \ # Enable EP --data-parallel-size 16 \ # Total DP size across all nodes @@ -169,11 +172,12 @@ Single node deployment with EPLB enabled: ```bash # Single node with EPLB load balancing -VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V3-0324 \ - --tensor-parallel-size 1 \ # Tensor parallelism - --data-parallel-size 8 \ # Data parallelism - --enable-expert-parallel \ # Enable EP - --enable-eplb \ # Enable load balancer +vllm serve deepseek-ai/DeepSeek-V3-0324 \ + --tensor-parallel-size 1 \ # Tensor parallelism + --data-parallel-size 8 \ # Data parallelism + --enable-expert-parallel \ # Enable EP + --all2all-backend pplx \ # Use pplx communication backend + --enable-eplb \ # Enable load balancer --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}' ``` @@ -191,7 +195,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok ### Setup Steps -1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. +1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](../../tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. For non-CUDA platforms, install `nixl` with a non-CUDA UCX build by running the [install_nixl_from_source_ubuntu.py](../../tools/install_nixl_from_source_ubuntu.py) script. 2. **Configure Both Instances**: Add this flag to both prefill and decode instances: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'`. Note that you may also specify one or more NIXL backends.
Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'` @@ -239,10 +243,10 @@ try: "remote_engine_id": None, # Will be populated by vLLM "remote_block_ids": None, # Will be populated by vLLM "remote_host": None, # Will be populated by vLLM - "remote_port": None # Will be populated by vLLM + "remote_port": None, # Will be populated by vLLM } }, - extra_headers={"X-Request-Id": request_id} + extra_headers={"X-Request-Id": request_id}, ) print("-" * 50) @@ -258,7 +262,7 @@ try: extra_body={ "kv_transfer_params": prefill_response.kv_transfer_params # Pass KV cache info }, - extra_headers={"X-Request-Id": request_id} # Same request ID + extra_headers={"X-Request-Id": request_id}, # Same request ID ) print("-" * 50) diff --git a/docs/serving/integrations/langchain.md b/docs/serving/integrations/langchain.md index 47074f411ac9..192a61ea5b90 100644 --- a/docs/serving/integrations/langchain.md +++ b/docs/serving/integrations/langchain.md @@ -15,13 +15,15 @@ To run inference on a single or multiple GPUs, use `VLLM` class from `langchain` ```python from langchain_community.llms import VLLM - llm = VLLM(model="mosaicml/mpt-7b", - trust_remote_code=True, # mandatory for hf models - max_new_tokens=128, - top_k=10, - top_p=0.95, - temperature=0.8, - # tensor_parallel_size=... # for distributed inference + llm = VLLM( + model="mosaicml/mpt-7b", + trust_remote_code=True, # mandatory for hf models + max_new_tokens=128, + top_k=10, + top_p=0.95, + temperature=0.8, + # for distributed inference + # tensor_parallel_size=..., ) print(llm("What is the capital of France ?")) diff --git a/docs/serving/offline_inference.md b/docs/serving/offline_inference.md index ddda47690002..b3d211871821 100644 --- a/docs/serving/offline_inference.md +++ b/docs/serving/offline_inference.md @@ -19,7 +19,7 @@ The available APIs depend on the model type: - [Pooling models](../models/pooling_models.md) output their hidden states directly. !!! info - [API Reference][offline-inference-api] + [API Reference](../api/README.md#offline-inference) ## Ray Data LLM API diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index fe0e1e3df378..1414718a697d 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -24,8 +24,8 @@ To call the server, in your preferred text editor, create a script that uses an completion = client.chat.completions.create( model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ - {"role": "user", "content": "Hello!"} - ] + {"role": "user", "content": "Hello!"}, + ], ) print(completion.choices[0].message) @@ -44,37 +44,35 @@ To call the server, in your preferred text editor, create a script that uses an We currently support the following OpenAI APIs: -- [Completions API][completions-api] (`/v1/completions`) +- [Completions API](#completions-api) (`/v1/completions`) - Only applicable to [text generation models](../models/generative_models.md). - *Note: `suffix` parameter is not supported.* -- [Chat Completions API][chat-api] (`/v1/chat/completions`) - - Only applicable to [text generation models](../models/generative_models.md) with a [chat template][chat-template]. +- [Chat Completions API](#chat-api) (`/v1/chat/completions`) + - Only applicable to [text generation models](../models/generative_models.md) with a [chat template](../serving/openai_compatible_server.md#chat-template). 
- *Note: `parallel_tool_calls` and `user` parameters are ignored.* -- [Embeddings API][embeddings-api] (`/v1/embeddings`) +- [Embeddings API](#embeddings-api) (`/v1/embeddings`) - Only applicable to [embedding models](../models/pooling_models.md). -- [Transcriptions API][transcriptions-api] (`/v1/audio/transcriptions`) +- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`) - Only applicable to [Automatic Speech Recognition (ASR) models](../models/supported_models.md#transcription). -- [Translation API][translations-api] (`/v1/audio/translations`) +- [Translation API](#translations-api) (`/v1/audio/translations`) - Only applicable to [Automatic Speech Recognition (ASR) models](../models/supported_models.md#transcription). In addition, we have the following custom APIs: -- [Tokenizer API][tokenizer-api] (`/tokenize`, `/detokenize`) +- [Tokenizer API](#tokenizer-api) (`/tokenize`, `/detokenize`) - Applicable to any model with a tokenizer. -- [Pooling API][pooling-api] (`/pooling`) +- [Pooling API](#pooling-api) (`/pooling`) - Applicable to all [pooling models](../models/pooling_models.md). -- [Classification API][classification-api] (`/classify`) +- [Classification API](#classification-api) (`/classify`) - Only applicable to [classification models](../models/pooling_models.md). -- [Score API][score-api] (`/score`) +- [Score API](#score-api) (`/score`) - Applicable to [embedding models and cross-encoder models](../models/pooling_models.md). -- [Re-rank API][rerank-api] (`/rerank`, `/v1/rerank`, `/v2/rerank`) +- [Re-rank API](#re-rank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`) - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/) - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank) - Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response. - Only applicable to [cross-encoder models](../models/pooling_models.md). -[](){ #chat-template } - ## Chat Template In order for the language model to support chat protocol, vLLM requires the model to include @@ -92,7 +90,7 @@ and all chat requests will error. vllm serve <model> --chat-template ./path-to-chat-template.jinja ``` -vLLM community provides a set of chat templates for popular models. You can find them under the <gh-dir:examples> directory. +vLLM community provides a set of chat templates for popular models. You can find them under the [examples](../../examples) directory. With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies both a `type` and a `text` field. An example is provided below: @@ -101,8 +99,13 @@ both a `type` and a `text` field. 
An example is provided below: completion = client.chat.completions.create( model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ - {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]} - ] + { + "role": "user", + "content": [ + {"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}, + ], + }, + ], ) ``` @@ -130,11 +133,11 @@ Or directly merge them into the JSON payload if you are using HTTP call directly completion = client.chat.completions.create( model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ - {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} + {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}, ], extra_body={ - "structured_outputs": {"choice": ["positive", "negative"]} - } + "structured_outputs": {"choice": ["positive", "negative"]}, + }, ) ``` @@ -149,11 +152,11 @@ with `--enable-request-id-headers`. completion = client.chat.completions.create( model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ - {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} + {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}, ], extra_headers={ "x-request-id": "sentiment-classification-00001", - } + }, ) print(completion._request_id) @@ -162,25 +165,23 @@ with `--enable-request-id-headers`. prompt="A robot may not injure a human being", extra_headers={ "x-request-id": "completion-test", - } + }, ) print(completion._request_id) ``` ## API Reference -[](){ #completions-api } - ### Completions API Our Completions API is compatible with [OpenAI's Completions API](https://platform.openai.com/docs/api-reference/completions); you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. -Code example: <gh-file:examples/online_serving/openai_completion_client.py> +Code example: [examples/online_serving/openai_completion_client.py](../../examples/online_serving/openai_completion_client.py) #### Extra parameters -The following [sampling parameters][sampling-params] are supported. +The following [sampling parameters](../api/README.md#inference-parameters) are supported. ??? code @@ -196,8 +197,6 @@ The following extra parameters are supported: --8<-- "vllm/entrypoints/openai/protocol.py:completion-extra-params" ``` -[](){ #chat-api } - ### Chat API Our Chat API is compatible with [OpenAI's Chat Completions API](https://platform.openai.com/docs/api-reference/chat); @@ -209,11 +208,11 @@ see our [Multimodal Inputs](../features/multimodal_inputs.md) guide for more inf - *Note: `image_url.detail` parameter is not supported.* -Code example: <gh-file:examples/online_serving/openai_chat_completion_client.py> +Code example: [examples/online_serving/openai_chat_completion_client.py](../../examples/online_serving/openai_chat_completion_client.py) #### Extra parameters -The following [sampling parameters][sampling-params] are supported. +The following [sampling parameters](../api/README.md#inference-parameters) are supported. ??? code @@ -229,16 +228,14 @@ The following extra parameters are supported: --8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-extra-params" ``` -[](){ #embeddings-api } - ### Embeddings API Our Embeddings API is compatible with [OpenAI's Embeddings API](https://platform.openai.com/docs/api-reference/embeddings); you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it. 
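For example, a minimal sketch of calling `/v1/embeddings` with the official OpenAI client, assuming a locally running server on port 8000 and treating `intfloat/e5-mistral-7b-instruct` purely as a placeholder model name:

```python
from openai import OpenAI

# Placeholder server URL and model; any embedding model served with
# `vllm serve <model> --runner pooling` should respond the same way.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

result = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input=["Hello, my name is", "The capital of France is"],
)
for item in result.data:
    print(len(item.embedding))  # dimensionality of each returned embedding
```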
-Code example: <gh-file:examples/online_serving/pooling/openai_embedding_client.py> +Code example: [examples/online_serving/pooling/openai_embedding_client.py](../../examples/online_serving/pooling/openai_embedding_client.py) -If the model has a [chat template][chat-template], you can replace `inputs` with a list of `messages` (same schema as [Chat API][chat-api]) +If the model has a [chat template](../serving/openai_compatible_server.md#chat-template), you can replace `inputs` with a list of `messages` (same schema as [Chat API](#chat-api)) which will be treated as a single prompt to the model. Here is a convenience function for calling the API while retaining OpenAI's type annotations: ??? code @@ -284,7 +281,7 @@ and passing a list of `messages` in the request. Refer to the examples below for to run this model in embedding mode instead of text generation mode. The custom chat template is completely different from the original one for this model, - and can be found here: <gh-file:examples/template_vlm2vec_phi3v.jinja> + and can be found here: [examples/template_vlm2vec_phi3v.jinja](../../examples/template_vlm2vec_phi3v.jinja) Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library: @@ -331,13 +328,13 @@ and passing a list of `messages` in the request. Refer to the examples below for Like with VLM2Vec, we have to explicitly pass `--runner pooling`. Additionally, `MrLight/dse-qwen2-2b-mrl-v1` requires an EOS token for embeddings, which is handled - by a custom chat template: <gh-file:examples/template_dse_qwen2_vl.jinja> + by a custom chat template: [examples/template_dse_qwen2_vl.jinja](../../examples/template_dse_qwen2_vl.jinja) !!! important `MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code example below for details. -Full example: <gh-file:examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py> +Full example: [examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py](../../examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py) #### Extra parameters @@ -364,8 +361,6 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s --8<-- "vllm/entrypoints/openai/protocol.py:chat-embedding-extra-params" ``` -[](){ #transcriptions-api } - ### Transcriptions API Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription); @@ -374,7 +369,7 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai !!! note To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`. -Code example: <gh-file:examples/online_serving/openai_transcription_client.py> +Code example: [examples/online_serving/openai_transcription_client.py](../../examples/online_serving/openai_transcription_client.py) #### API Enforced Limits @@ -403,7 +398,7 @@ The Transcriptions API supports uploading audio files in various formats includi model="openai/whisper-large-v3-turbo", file=audio_file, language="en", - response_format="verbose_json" + response_format="verbose_json", ) print(transcription.text) @@ -463,7 +458,7 @@ For `verbose_json` response format: #### Extra Parameters -The following [sampling parameters][sampling-params] are supported. +The following [sampling parameters](../api/README.md#inference-parameters) are supported. 
??? code @@ -479,8 +474,6 @@ The following extra parameters are supported: --8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params" ``` -[](){ #translations-api } - ### Translations API Our Translation API is compatible with [OpenAI's Translations API](https://platform.openai.com/docs/api-reference/audio/createTranslation); @@ -491,11 +484,11 @@ Please mind that the popular `openai/whisper-large-v3-turbo` model does not supp !!! note To use the Translation API, please install with extra audio dependencies using `pip install vllm[audio]`. -Code example: <gh-file:examples/online_serving/openai_translation_client.py> +Code example: [examples/online_serving/openai_translation_client.py](../../examples/online_serving/openai_translation_client.py) #### Extra Parameters -The following [sampling parameters][sampling-params] are supported. +The following [sampling parameters](../api/README.md#inference-parameters) are supported. ```python --8<-- "vllm/entrypoints/openai/protocol.py:translation-sampling-params" @@ -507,8 +500,6 @@ The following extra parameters are supported: --8<-- "vllm/entrypoints/openai/protocol.py:translation-extra-params" ``` -[](){ #tokenizer-api } - ### Tokenizer API Our Tokenizer API is a simple wrapper over [HuggingFace-style tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer). @@ -517,17 +508,13 @@ It consists of two endpoints: - `/tokenize` corresponds to calling `tokenizer.encode()`. - `/detokenize` corresponds to calling `tokenizer.decode()`. -[](){ #pooling-api } - ### Pooling API Our Pooling API encodes input prompts using a [pooling model](../models/pooling_models.md) and returns the corresponding hidden states. -The input format is the same as [Embeddings API][embeddings-api], but the output data can contain an arbitrary nested list, not just a 1-D list of floats. - -Code example: <gh-file:examples/online_serving/pooling/openai_pooling_client.py> +The input format is the same as [Embeddings API](#embeddings-api), but the output data can contain an arbitrary nested list, not just a 1-D list of floats. -[](){ #classification-api } +Code example: [examples/online_serving/pooling/openai_pooling_client.py](../../examples/online_serving/pooling/openai_pooling_client.py) ### Classification API @@ -535,7 +522,7 @@ Our Classification API directly supports Hugging Face sequence-classification mo We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. -Code example: <gh-file:examples/online_serving/pooling/openai_classification_client.py> +Code example: [examples/online_serving/pooling/openai_classification_client.py](../../examples/online_serving/pooling/openai_classification_client.py) #### Example Requests @@ -644,8 +631,6 @@ The following extra parameters are supported: --8<-- "vllm/entrypoints/openai/protocol.py:classification-extra-params" ``` -[](){ #score-api } - ### Score API Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence or multimodal pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair. @@ -653,7 +638,7 @@ Usually, the score for a sentence pair refers to the similarity between two sent You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). 
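For example, a minimal sketch of scoring one query against two candidate sentences over `/score` with `requests`, assuming a locally running server and treating `BAAI/bge-reranker-v2-m3` as a placeholder cross-encoder model:

```python
import requests

# Placeholder server URL and model name.
response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-v2-m3",
        "text_1": "What is the capital of France?",
        "text_2": ["The capital of France is Paris.", "Horses eat grass."],
    },
)
response.raise_for_status()
for item in response.json()["data"]:
    print(item["score"])  # one relevance score per sentence pair
```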
-Code example: <gh-file:examples/online_serving/openai_cross_encoder_score.py> +Code example: [examples/online_serving/openai_cross_encoder_score.py](../../examples/online_serving/openai_cross_encoder_score.py) #### Single inference @@ -812,29 +797,29 @@ You can pass multi-modal inputs to scoring models by passing `content` including "model": "jinaai/jina-reranker-m0", "text_1": "slm markdown", "text_2": { - "content": [ - { - "type": "image_url", - "image_url": { - "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png" - }, - }, - { - "type": "image_url", - "image_url": { - "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" - }, - }, - ] - } + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" + }, + }, + ], }, + }, ) response.raise_for_status() response_json = response.json() print("Scoring output:", response_json["data"][0]["score"]) print("Scoring output:", response_json["data"][1]["score"]) ``` -Full example: <gh-file:examples/online_serving/openai_cross_encoder_score_for_multimodal.py> +Full example: [examples/online_serving/openai_cross_encoder_score_for_multimodal.py](../../examples/online_serving/openai_cross_encoder_score_for_multimodal.py) #### Extra parameters @@ -851,8 +836,6 @@ The following extra parameters are supported: --8<-- "vllm/entrypoints/openai/protocol.py:score-extra-params" ``` -[](){ #rerank-api } - ### Re-rank API Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and @@ -866,7 +849,7 @@ endpoints are compatible with both [Jina AI's re-rank API interface](https://jin [Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank) to ensure compatibility with popular open-source tools. -Code example: <gh-file:examples/online_serving/pooling/jinaai_rerank_client.py> +Code example: [examples/online_serving/pooling/jinaai_rerank_client.py](../../examples/online_serving/pooling/jinaai_rerank_client.py) #### Example Request @@ -944,6 +927,6 @@ Key capabilities: - Scales from a single GPU to a multi-node cluster without code changes. - Provides observability and autoscaling policies through Ray dashboards and metrics. -The following example shows how to deploy a large model like DeepSeek R1 with Ray Serve LLM: <gh-file:examples/online_serving/ray_serve_deepseek.py>. +The following example shows how to deploy a large model like DeepSeek R1 with Ray Serve LLM: [examples/online_serving/ray_serve_deepseek.py](../../examples/online_serving/ray_serve_deepseek.py). Learn more about Ray Serve LLM with the official [Ray Serve LLM documentation](https://docs.ray.io/en/latest/serve/llm/serving-llms.html). diff --git a/docs/serving/parallelism_scaling.md b/docs/serving/parallelism_scaling.md index cef1127fc5c1..14cd3b057791 100644 --- a/docs/serving/parallelism_scaling.md +++ b/docs/serving/parallelism_scaling.md @@ -72,7 +72,7 @@ For details, see the [Ray documentation](https://docs.ray.io/en/latest/index.htm ### Ray cluster setup with containers -The helper script <gh-file:examples/online_serving/run_cluster.sh> starts containers across nodes and initializes Ray. 
By default, the script runs Docker without administrative privileges, which prevents access to the GPU performance counters when profiling or tracing. To enable admin privileges, add the `--cap-add=CAP_SYS_ADMIN` flag to the Docker command. +The helper script [examples/online_serving/run_cluster.sh](../../examples/online_serving/run_cluster.sh) starts containers across nodes and initializes Ray. By default, the script runs Docker without administrative privileges, which prevents access to the GPU performance counters when profiling or tracing. To enable admin privileges, add the `--cap-add=CAP_SYS_ADMIN` flag to the Docker command. Choose one node as the head node and run: @@ -132,7 +132,7 @@ vllm serve /path/to/the/model/in/the/container \ Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand. To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the -<gh-file:examples/online_serving/run_cluster.sh> helper script. +[examples/online_serving/run_cluster.sh](../../examples/online_serving/run_cluster.sh) helper script. Contact your system administrator for more information about the required flags. ## Enabling GPUDirect RDMA diff --git a/docs/training/rlhf.md b/docs/training/rlhf.md index b207c9ed373b..0b7e384dc8d6 100644 --- a/docs/training/rlhf.md +++ b/docs/training/rlhf.md @@ -5,6 +5,7 @@ Reinforcement Learning from Human Feedback (RLHF) is a technique that fine-tunes The following open-source RL libraries use vLLM for fast rollouts (sorted alphabetically and non-exhaustive): - [Cosmos-RL](https://github.com/nvidia-cosmos/cosmos-rl) +- [ms-swift](https://github.com/modelscope/ms-swift/tree/main) - [NeMo-RL](https://github.com/NVIDIA-NeMo/RL) - [Open Instruct](https://github.com/allenai/open-instruct) - [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) diff --git a/docs/usage/reproducibility.md b/docs/usage/reproducibility.md index a494dcf19191..d8a1943209c1 100644 --- a/docs/usage/reproducibility.md +++ b/docs/usage/reproducibility.md @@ -6,7 +6,7 @@ reproducible results: - For V1: Turn off multiprocessing to make the scheduling deterministic by setting `VLLM_ENABLE_V1_MULTIPROCESSING=0`. - For V0: Set the global seed (see below). -Example: <gh-file:examples/offline_inference/reproducibility.py> +Example: [examples/offline_inference/reproducibility.py](../../examples/offline_inference/reproducibility.py) !!! warning @@ -39,7 +39,7 @@ In V1, the `seed` parameter defaults to `0` which sets the random state for each It is impossible to un-specify a seed for V1 because different workers need to sample the same outputs for workflows such as speculative decoding. - For more information, see: <gh-pr:17929> + For more information, see: <https://github.com/vllm-project/vllm/pull/17929> ### Locality of random state diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index 6e700d1faaa9..94e801376e53 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -24,7 +24,7 @@ If the model is too large to fit in a single GPU, you will get an out-of-memory ## Generation quality changed -In v0.8.0, the source of default sampling parameters was changed in <gh-pr:12622>. Prior to v0.8.0, the default sampling parameters came from vLLM's set of neutral defaults. From v0.8.0 onwards, the default sampling parameters come from the `generation_config.json` provided by the model creator. 
+In v0.8.0, the source of default sampling parameters was changed in <https://github.com/vllm-project/vllm/pull/12622>. Prior to v0.8.0, the default sampling parameters came from vLLM's set of neutral defaults. From v0.8.0 onwards, the default sampling parameters come from the `generation_config.json` provided by the model creator. In most cases, this should lead to higher quality responses, because the model creator is likely to know which sampling parameters are best for their model. However, in some cases the defaults provided by the model creator can lead to degraded performance. @@ -38,7 +38,7 @@ If other strategies don't solve the problem, it's likely that the vLLM instance - `export VLLM_LOG_STATS_INTERVAL=1.` to get log statistics more frequently for tracking running queue, waiting queue and cache hit states. - `export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem. - `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL. -- `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time. +- `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. (WARNING: This flag will slow down the token generation by **over 100x**. Do not use unless absolutely needed.) ## Breakpoints @@ -80,8 +80,6 @@ You might also need to set `export NCCL_SOCKET_IFNAME=<your_network_interface>` If vLLM crashes and the error trace captures it somewhere around `self.graph.replay()` in `vllm/worker/model_runner.py`, it is a CUDA error inside CUDAGraph. To identify the particular CUDA operation that causes the error, you can add `--enforce-eager` to the command line, or `enforce_eager=True` to the [LLM][vllm.LLM] class to disable the CUDAGraph optimization and isolate the exact CUDA operation that causes the error. -[](){ #troubleshooting-incorrect-hardware-driver } - ## Incorrect hardware/driver If GPU/CPU communication cannot be established, you can use the following Python script and follow the instructions below to confirm whether the GPU/CPU communication is working correctly. @@ -178,8 +176,6 @@ If the test script hangs or crashes, usually it means the hardware/drivers are b Adjust `--nproc-per-node`, `--nnodes`, and `--node-rank` according to your setup, being sure to execute different commands (with different `--node-rank`) on different nodes. -[](){ #troubleshooting-python-multiprocessing } - ## Python multiprocessing ### `RuntimeError` Exception @@ -238,7 +234,7 @@ if __name__ == '__main__': ## `torch.compile` Error -vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces the dependency on the `torch.compile` functionality and the `triton` library. By default, we use `torch.compile` to [optimize some functions](gh-pr:10406) in the model. Before running vLLM, you can check if `torch.compile` is working as expected by running the following script: +vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces the dependency on the `torch.compile` functionality and the `triton` library. By default, we use `torch.compile` to [optimize some functions](https://github.com/vllm-project/vllm/pull/10406) in the model. Before running vLLM, you can check if `torch.compile` is working as expected by running the following script: ??? 
code @@ -257,7 +253,7 @@ vLLM heavily depends on `torch.compile` to optimize the model for better perform print(f(x)) ``` -If it raises errors from `torch/_inductor` directory, usually it means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. See <gh-issue:12219> for example. +If it raises errors from `torch/_inductor` directory, usually it means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. See <https://github.com/vllm-project/vllm/issues/12219> for example. ## Model failed to be inspected @@ -297,7 +293,7 @@ But you are sure that the model is in the [list of supported models](../models/s ## Failed to infer device type -If you see an error like `RuntimeError: Failed to infer device type`, it means that vLLM failed to infer the device type of the runtime environment. You can check [the code](gh-file:vllm/platforms/__init__.py) to see how vLLM infers the device type and why it is not working as expected. After [this PR](gh-pr:14195), you can also set the environment variable `VLLM_LOGGING_LEVEL=DEBUG` to see more detailed logs to help debug the issue. +If you see an error like `RuntimeError: Failed to infer device type`, it means that vLLM failed to infer the device type of the runtime environment. You can check [the code](../../vllm/platforms/__init__.py) to see how vLLM infers the device type and why it is not working as expected. After [this PR](https://github.com/vllm-project/vllm/pull/14195), you can also set the environment variable `VLLM_LOGGING_LEVEL=DEBUG` to see more detailed logs to help debug the issue. ## NCCL error: unhandled system error during `ncclCommInitRank` @@ -322,6 +318,6 @@ This indicates vLLM failed to initialize the NCCL communicator, possibly due to ## Known Issues -- In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759). +- In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](https://github.com/vllm-project/vllm/pull/6759). - To address a memory overhead issue in older NCCL versions (see [bug](https://github.com/NVIDIA/nccl/issues/1234)), vLLM versions `>= 0.4.3, <= 0.10.1.1` would set the environment variable `NCCL_CUMEM_ENABLE=0`. External processes connecting to vLLM also needed to set this variable to prevent hangs or crashes. Since the underlying NCCL bug was fixed in NCCL 2.22.3, this override was removed in newer vLLM versions to allow for NCCL performance optimizations. - In some PCIe machines (e.g. machines without NVLink), if you see an error like `transport/shm.cc:590 NCCL WARN Cuda failure 217 'peer access is not supported between these two devices'`, it's likely caused by a driver bug. See [this issue](https://github.com/NVIDIA/nccl/issues/1838) for more details. In that case, you can try to set `NCCL_CUMEM_HOST_ENABLE=0` to disable the feature, or upgrade your driver to the latest version. 
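As a quick illustration of the `--enforce-eager` workaround described in the CUDAGraph capture section above, a minimal offline sketch (the model name is only a placeholder):

```python
from vllm import LLM

# Disable CUDA graph capture so the failing CUDA operation surfaces directly;
# this has the same effect as passing --enforce-eager on the command line.
llm = LLM(model="NousResearch/Meta-Llama-3-8B-Instruct", enforce_eager=True)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```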
diff --git a/docs/usage/usage_stats.md b/docs/usage/usage_stats.md index 4c7a7ff019e8..6225478d52d0 100644 --- a/docs/usage/usage_stats.md +++ b/docs/usage/usage_stats.md @@ -6,7 +6,7 @@ A subset of the data, after cleaning and aggregation, will be publicly released ## What data is collected? -The list of data collected by the latest version of vLLM can be found here: <gh-file:vllm/usage/usage_lib.py> +The list of data collected by the latest version of vLLM can be found here: [vllm/usage/usage_lib.py](../../vllm/usage/usage_lib.py) Here is an example as of v0.4.0: diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 340aaf54bb72..c47547cb0ea7 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -2,7 +2,7 @@ !!! announcement - We have started the process of deprecating V0. Please read [RFC #18571](gh-issue:18571) for more details. + We have started the process of deprecating V0. Please read [RFC #18571](https://github.com/vllm-project/vllm/issues/18571) for more details. V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack). @@ -88,20 +88,14 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> | | **Multimodal Models** | <nobr>🟢 Functional</nobr> | -vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol. - -!!! tip - - This corresponds to the V1 column in our [list of supported models](../models/supported_models.md). - See below for the status of models that are not yet supported or have more features planned in V1. #### Embedding Models The initial basic support is now functional. -Later, we will consider using [hidden states processor](gh-issue:12249), -which is based on [global logits processor](gh-pr:13360) +Later, we will consider using [hidden states processor](https://github.com/vllm-project/vllm/issues/12249), +which is based on [global logits processor](https://github.com/vllm-project/vllm/pull/13360) to enable simultaneous generation and embedding using the same engine instance in V1. 
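For example, a minimal offline sketch of the embedding path that is already functional in V1, assuming a pooling-runner model such as `BAAI/bge-m3`:

```python
from vllm import LLM

# Embedding models need the pooling runner.
llm = LLM(model="BAAI/bge-m3", runner="pooling", enforce_eager=True)
outputs = llm.embed(["Hello, my name is", "The capital of France is"])
for output in outputs:
    print(len(output.outputs.embedding))  # one embedding vector per prompt
```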
#### Mamba Models @@ -130,13 +124,13 @@ encoder and decoder (e.g., `BartForConditionalGeneration`, | **Chunked Prefill** | <nobr>🚀 Optimized</nobr> | | **LoRA** | <nobr>🚀 Optimized</nobr> | | **Logprobs Calculation** | <nobr>🟢 Functional</nobr> | -| **FP8 KV Cache** | <nobr>🟢 Functional on Hopper devices (<gh-pr:15191>)</nobr>| +| **FP8 KV Cache** | <nobr>🟢 Functional on Hopper devices (<https://github.com/vllm-project/vllm/pull/15191>)</nobr>| | **Spec Decode** | <nobr>🚀 Optimized</nobr> | -| **Prompt Logprobs with Prefix Caching** | <nobr>🟡 Planned ([RFC #13414](gh-issue:13414))</nobr>| +| **Prompt Logprobs with Prefix Caching** | <nobr>🟡 Planned ([RFC #13414](https://github.com/vllm-project/vllm/issues/13414))</nobr>| | **Structured Output Alternative Backends** | <nobr>🟢 Functional</nobr> | | **Request-level Structured Output Backend** | <nobr>🔴 Deprecated</nobr> | -| **best_of** | <nobr>🔴 Deprecated ([RFC #13361](gh-issue:13361))</nobr>| -| **Per-Request Logits Processors** | <nobr>🔴 Deprecated ([RFC #13360](gh-pr:13360))</nobr> | +| **best_of** | <nobr>🔴 Deprecated ([RFC #13361](https://github.com/vllm-project/vllm/issues/13361))</nobr>| +| **Per-Request Logits Processors** | <nobr>🔴 Deprecated ([RFC #13360](https://github.com/vllm-project/vllm/pull/13360))</nobr> | | **GPU <> CPU KV Cache Swapping** | <nobr>🔴 Deprecated</nobr> | !!! note @@ -174,11 +168,11 @@ As part of the major architectural rework in vLLM V1, several legacy features ha ##### Sampling features -- **best_of**: This feature has been deprecated due to limited usage. See details at [RFC #13361](gh-issue:13361). +- **best_of**: This feature has been deprecated due to limited usage. See details at [RFC #13361](https://github.com/vllm-project/vllm/issues/13361). - **Per-Request Logits Processors**: In V0, users could pass custom processing functions to adjust logits on a per-request basis. In vLLM V1, this feature has been deprecated. Instead, the design is moving toward supporting **global logits - processors**, a feature the team is actively working on for future releases. See details at [RFC #13360](gh-pr:13360). + processors**, a feature the team is actively working on for future releases. See details at [RFC #13360](https://github.com/vllm-project/vllm/pull/13360). 
##### KV Cache features diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 65a87d2dd9e8..c4eed2037781 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -10,7 +10,7 @@ import os from dataclasses import asdict -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple from huggingface_hub import snapshot_download from transformers import AutoTokenizer @@ -30,11 +30,11 @@ class ModelRequestData(NamedTuple): engine_args: EngineArgs - prompt: Optional[str] = None - prompt_token_ids: Optional[dict[str, list[int]]] = None - multi_modal_data: Optional[dict[str, Any]] = None - stop_token_ids: Optional[list[int]] = None - lora_requests: Optional[list[LoRARequest]] = None + prompt: str | None = None + prompt_token_ids: dict[str, list[int]] | None = None + multi_modal_data: dict[str, Any] | None = None + stop_token_ids: list[int] | None = None + lora_requests: list[LoRARequest] | None = None # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on @@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple): # Voxtral def run_voxtral(question: str, audio_count: int) -> ModelRequestData: from mistral_common.audio import Audio - from mistral_common.protocol.instruct.messages import ( + from mistral_common.protocol.instruct.chunk import ( AudioChunk, RawAudio, TextChunk, + ) + from mistral_common.protocol.instruct.messages import ( UserMessage, ) from mistral_common.protocol.instruct.request import ChatCompletionRequest diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 0076d4d30ee8..0b281fc41a34 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -33,7 +33,7 @@ from time import sleep from vllm import LLM, SamplingParams -from vllm.utils import get_open_port +from vllm.utils.network_utils import get_open_port def parse_args(): @@ -95,7 +95,7 @@ def parse_args(): parser.add_argument( "--compilation-config", type=int, - help=("Compilation optimization (O) level 0-3."), + help=("Compilation optimization (O) mode 0-3."), ) parser.add_argument( "--quantization", diff --git a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py index 0abe7d161261..5b2acea4c945 100644 --- a/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py +++ b/examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py @@ -3,7 +3,7 @@ # ruff: noqa: E501 import logging from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -81,7 +81,7 @@ def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None: def get_finished( self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + ) -> tuple[set[str] | None, set[str] | None]: if self._async_load: meta = self._get_connector_metadata() assert isinstance(meta, RogueSharedStorageConnectorMetadata) diff --git a/examples/offline_inference/logits_processor/custom.py b/examples/offline_inference/logits_processor/custom.py index 4112a498f37a..72e7ce24d7cc 100644 --- a/examples/offline_inference/logits_processor/custom.py +++ 
b/examples/offline_inference/logits_processor/custom.py @@ -33,8 +33,6 @@ class object. ------------------------------------------------------------ """ -from typing import Optional - import torch from vllm import LLM, SamplingParams @@ -58,7 +56,7 @@ def __init__( def is_argmax_invariant(self) -> bool: return False - def update_state(self, batch_update: Optional[BatchUpdate]): + def update_state(self, batch_update: BatchUpdate | None): process_dict_updates( self.req_info, batch_update, diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py index 4c19bb4ce2ba..87cd7473fa9f 100644 --- a/examples/offline_inference/logits_processor/custom_req.py +++ b/examples/offline_inference/logits_processor/custom_req.py @@ -39,7 +39,7 @@ ------------------------------------------------------------ """ -from typing import Any, Optional +from typing import Any import torch @@ -82,7 +82,7 @@ def is_argmax_invariant(self) -> bool: def new_req_logits_processor( self, params: SamplingParams, - ) -> Optional[RequestLogitsProcessor]: + ) -> RequestLogitsProcessor | None: """This method returns a new request-level logits processor, customized to the `target_token` value associated with a particular request. @@ -96,7 +96,7 @@ def new_req_logits_processor( Returns: `Callable` request logits processor, or None """ - target_token: Optional[Any] = params.extra_args and params.extra_args.get( + target_token: Any | None = params.extra_args and params.extra_args.get( "target_token" ) if target_token is None: diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py index 62947d122e01..3bb82a786040 100644 --- a/examples/offline_inference/logits_processor/custom_req_init.py +++ b/examples/offline_inference/logits_processor/custom_req_init.py @@ -41,8 +41,6 @@ device, the first and third requests would not repeat the same token. """ -from typing import Optional - import torch from vllm import LLM, SamplingParams @@ -91,7 +89,7 @@ def is_argmax_invariant(self) -> bool: def new_req_logits_processor( self, params: SamplingParams, - ) -> Optional[RequestLogitsProcessor]: + ) -> RequestLogitsProcessor | None: """This method returns a new request-level logits processor, customized to the `target_token` value associated with a particular request. 
diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index 00d4cb9eb4c4..dc5c6202fa57 100644 --- a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -8,7 +8,6 @@ """ import gc -from typing import Optional import torch from huggingface_hub import snapshot_download @@ -19,7 +18,7 @@ def create_test_prompts( lora_path: str, -) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: +) -> list[tuple[str, SamplingParams, LoRARequest | None]]: return [ # this is an example of using quantization without LoRA ( @@ -56,7 +55,7 @@ def create_test_prompts( def process_requests( engine: LLMEngine, - test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]], + test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]], ): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 @@ -78,7 +77,7 @@ def process_requests( def initialize_engine( - model: str, quantization: str, lora_repo: Optional[str] + model: str, quantization: str, lora_repo: str | None ) -> LLMEngine: """Initialize the LLMEngine.""" diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py index 6040683c68bc..6c23cf342e06 100644 --- a/examples/offline_inference/multilora_inference.py +++ b/examples/offline_inference/multilora_inference.py @@ -7,8 +7,6 @@ Requires HuggingFace credentials for access to Llama2. """ -from typing import Optional - from huggingface_hub import snapshot_download from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams @@ -17,7 +15,7 @@ def create_test_prompts( lora_path: str, -) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: +) -> list[tuple[str, SamplingParams, LoRARequest | None]]: """Create a list of test prompts with their sampling parameters. 2 requests for base model, 4 requests for the LoRA. 
We define 2 @@ -68,7 +66,7 @@ def create_test_prompts( def process_requests( engine: LLMEngine, - test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]], + test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]], ): """Continuously process a list of prompts and handle the outputs.""" request_id = 0 diff --git a/examples/offline_inference/openai_batch/README.md b/examples/offline_inference/openai_batch/README.md index 3c6f6c7a6c58..7d5a1af8f5a4 100644 --- a/examples/offline_inference/openai_batch/README.md +++ b/examples/offline_inference/openai_batch/README.md @@ -152,7 +152,9 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_ """ try: url = s3_client.generate_presigned_url( - ClientMethod=client_method, Params=method_parameters, ExpiresIn=expires_in + ClientMethod=client_method, + Params=method_parameters, + ExpiresIn=expires_in, ) except ClientError: raise @@ -161,10 +163,16 @@ def generate_presigned_url(s3_client, client_method, method_parameters, expires_ s3_client = boto3.client("s3") input_url = generate_presigned_url( - s3_client, "get_object", {"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"}, 3600 + s3_client, + "get_object", + {"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"}, + expires_in=3600, ) output_url = generate_presigned_url( - s3_client, "put_object", {"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"}, 3600 + s3_client, + "put_object", + {"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"}, + expires_in=3600, ) print(f"{input_url=}") print(f"{output_url=}") diff --git a/examples/offline_inference/pooling/README.md b/examples/offline_inference/pooling/README.md index 79afbd9cfac4..cd9717122b16 100644 --- a/examples/offline_inference/pooling/README.md +++ b/examples/offline_inference/pooling/README.md @@ -14,7 +14,7 @@ python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_na ## Embed jina_embeddings_v3 usage -Only text matching task is supported for now. See <gh-pr:16120> +Only text matching task is supported for now. See <https://github.com/vllm-project/vllm/pull/16120> ```bash python examples/offline_inference/pooling/embed_jina_embeddings_v3.py @@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py python examples/offline_inference/pooling/embed_matryoshka_fy.py ``` +## Multi vector retrieval usage + +```bash +python examples/offline_inference/pooling/multi_vector_retrieval.py +``` + ## Named Entity Recognition (NER) usage ```bash diff --git a/examples/offline_inference/pooling/multi_vector_retrieval.py b/examples/offline_inference/pooling/multi_vector_retrieval.py new file mode 100644 index 000000000000..8b8892117d37 --- /dev/null +++ b/examples/offline_inference/pooling/multi_vector_retrieval.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from argparse import Namespace + +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def parse_args(): + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + # Set example specific arguments + parser.set_defaults( + model="BAAI/bge-m3", + runner="pooling", + enforce_eager=True, + ) + return parser.parse_args() + + +def main(args: Namespace): + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Create an LLM. 
+ # You should pass runner="pooling" for embedding models + llm = LLM(**vars(args)) + + # Generate embedding. The output is a list of EmbeddingRequestOutputs. + outputs = llm.embed(prompts) + + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for prompt, output in zip(prompts, outputs): + embeds = output.outputs.embedding + print(len(embeds)) + + # Generate embedding for each token. The output is a list of PoolingRequestOutput. + outputs = llm.encode(prompts, pooling_task="token_embed") + + # Print the outputs. + print("\nGenerated Outputs:\n" + "-" * 60) + for prompt, output in zip(prompts, outputs): + multi_vector = output.outputs.data + print(multi_vector.shape) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 1a5879a6d35f..b093c77c00b7 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -3,7 +3,6 @@ import argparse import datetime import os -from typing import Union import albumentations import numpy as np @@ -50,6 +49,7 @@ def __init__(self, model): dtype="float16", enforce_eager=True, model_impl="terratorch", + enable_mm_embeds=True, ) def run(self, input_data, location_coords): @@ -64,7 +64,7 @@ def run(self, input_data, location_coords): } prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} - outputs = self.model.encode(prompt, use_tqdm=False) + outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False) return outputs[0].outputs.data @@ -160,7 +160,7 @@ def load_example( file_paths: list[str], mean: list[float] = None, std: list[float] = None, - indices: Union[list[int], None] = None, + indices: list[int] | None = None, ): """Build an input example by loading images in *file_paths*. diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py index 418c40645f9f..b8637b89e08f 100644 --- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py +++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py @@ -6,14 +6,14 @@ import torch from vllm import LLM -from vllm.pooling_params import PoolingParams # This example shows how to perform an offline inference that generates # multimodal data. In this specific case this example will take a geotiff # image as input, process it using the multimodal data processor, and # perform inference. -# Requirement - install plugin at: -# https://github.com/christian-pinto/prithvi_io_processor_plugin +# Requirements: +# - install TerraTorch v1.1 (or later): +# pip install terratorch>=v1.1 def main(): @@ -36,15 +36,12 @@ def main(): # to avoid the model going OOM. 
# The maximum number depends on the available GPU memory max_num_seqs=32, - io_processor_plugin="prithvi_to_tiff", + io_processor_plugin="terratorch_segmentation", model_impl="terratorch", + enable_mm_embeds=True, ) - pooling_params = PoolingParams(task="encode", softmax=False) - pooler_output = llm.encode( - img_prompt, - pooling_params=pooling_params, - ) + pooler_output = llm.encode(img_prompt, pooling_task="plugin") output = pooler_output[0].outputs print(output) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index ed974b90b57e..0c09e603271d 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -38,7 +38,7 @@ from transformers import AutoModelForCausalLM from vllm import LLM, SamplingParams -from vllm.utils import get_ip, get_open_port +from vllm.utils.network_utils import get_ip, get_open_port class MyLLM(LLM): diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index c0e60b979340..13def88439ef 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc -from typing import Callable, Optional, TypedDict +from collections.abc import Callable +from typing import TypedDict import torch import zmq @@ -71,7 +72,7 @@ def check_weights_changed(self): def rebuild_ipc( - handle: tuple[Callable, tuple], device_id: Optional[int] = None + handle: tuple[Callable, tuple], device_id: int | None = None ) -> torch.Tensor: func, args = handle list_args = list(args) @@ -109,7 +110,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): self._zmq_ctx = zmq.Context() socket = self._zmq_ctx.socket(zmq.REP) socket.connect(zmq_handles[self.report_device_id()]) - buffer: Optional[torch.Tensor] = None + buffer: torch.Tensor | None = None while True: payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = ( socket.recv_pyobj() diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 9fd9da3b0855..7668b10916ac 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -12,7 +12,7 @@ import random from contextlib import contextmanager from dataclasses import asdict -from typing import NamedTuple, Optional +from typing import NamedTuple from huggingface_hub import snapshot_download from transformers import AutoTokenizer @@ -28,8 +28,9 @@ class ModelRequestData(NamedTuple): engine_args: EngineArgs prompts: list[str] - stop_token_ids: Optional[list[int]] = None - lora_requests: Optional[list[LoRARequest]] = None + stop_token_ids: list[int] | None = None + lora_requests: list[LoRARequest] | None = None + sampling_params: list[SamplingParams] | None = None # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on @@ -90,16 +91,25 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData: ) -# BLIP-2 -def run_blip2(questions: list[str], modality: str) -> ModelRequestData: +# Bee-8B +def run_bee(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" + model_name = "Open-Bee/Bee-8B-RL" + + prompts = [ + ( + f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<image>\n{question}<|im_end|>" + f"<|im_start|>assistant\n<think>\n" + ) + for question in questions + ] - # BLIP-2 prompt format is 
inaccurate on HuggingFace model repository. - # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa - prompts = [f"Question: {question} Answer:" for question in questions] engine_args = EngineArgs( - model="Salesforce/blip2-opt-2.7b", + model=model_name, + max_model_len=16384, limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, ) return ModelRequestData( @@ -108,15 +118,15 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData: ) -# Chameleon -def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: +# BLIP-2 +def run_blip2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" - prompts = [f"{question}<image>" for question in questions] + # BLIP-2 prompt format is inaccurate on HuggingFace model repository. + # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa + prompts = [f"Question: {question} Answer:" for question in questions] engine_args = EngineArgs( - model="facebook/chameleon-7b", - max_model_len=4096, - max_num_seqs=2, + model="Salesforce/blip2-opt-2.7b", limit_mm_per_prompt={modality: 1}, ) @@ -126,15 +136,16 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: ) -# Dots-OCR -def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: +# Chameleon +def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" - prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] + prompts = [f"{question}<image>" for question in questions] engine_args = EngineArgs( - model="rednote-hilab/dots.ocr", + model="facebook/chameleon-7b", + max_model_len=4096, + max_num_seqs=2, limit_mm_per_prompt={modality: 1}, - trust_remote_code=True, ) return ModelRequestData( @@ -190,6 +201,66 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) +def run_deepseek_ocr(questions: list[str], modality: str) -> ModelRequestData: + from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor + + assert modality == "image" + + model_name = "deepseek-ai/DeepSeek-OCR" + + engine_args = EngineArgs( + model=model_name, + limit_mm_per_prompt={modality: 1}, + logits_processors=[NGramPerReqLogitsProcessor], + ) + + # deepseek-ocr use plain prompt template + prompts = [f"<image>\n{question}" for question in questions] + + # The following sampling params config is taken from + # the official Deepseek-OCR inference example. + # (IMPORTANT) Use the custom logits processor and avoid skipping + # special tokens for this model for the optimal OCR performance. 
+ sampling_params = [ + SamplingParams( + temperature=0.0, + max_tokens=8192, + # ngram logit processor args + extra_args=dict( + ngram_size=30, + window_size=90, + # whitelist: <td>, </td> + whitelist_token_ids={128821, 128822}, + ), + skip_special_tokens=False, + ) + for _ in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + sampling_params=sampling_params, + ) + + +# Dots-OCR +def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] + engine_args = EngineArgs( + model="rednote-hilab/dots.ocr", + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Ernie4.5-VL def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" @@ -733,6 +804,26 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: ) +# LightOnOCR +def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [ + "<|im_start|>system<|im_end|>\n<|im_start|>user\n<|image_pad|><|im_end|>\n<|im_start|>assistant\n" + for _ in questions + ] + + engine_args = EngineArgs( + model="lightonai/LightOnOCR-1B", + limit_mm_per_prompt={modality: 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_llama4(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1687,11 +1778,13 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: model_example_map = { "aria": run_aria, "aya_vision": run_aya_vision, + "bee": run_bee, "blip-2": run_blip2, "chameleon": run_chameleon, - "dots_ocr": run_dots_ocr, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, + "deepseek_ocr": run_deepseek_ocr, + "dots_ocr": run_dots_ocr, "ernie45_vl": run_ernie45_vl, "fuyu": run_fuyu, "gemma3": run_gemma3, @@ -1708,6 +1801,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: "keye_vl": run_keye_vl, "keye_vl1_5": run_keye_vl1_5, "kimi_vl": run_kimi_vl, + "lightonocr": run_lightonocr, "llama4": run_llama4, "llava": run_llava, "llava-next": run_llava_next, @@ -1953,8 +2047,12 @@ def main(args): # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. 
- sampling_params = SamplingParams( - temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids + sampling_params = ( + SamplingParams( + temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids + ) + if req_data.sampling_params is None + else req_data.sampling_params ) assert args.num_prompts > 0 diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index c37d40a23ac2..b9115121a946 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -9,7 +9,7 @@ import os from argparse import Namespace from dataclasses import asdict -from typing import NamedTuple, Optional +from typing import NamedTuple from huggingface_hub import snapshot_download from PIL.Image import Image @@ -41,9 +41,10 @@ class ModelRequestData(NamedTuple): engine_args: EngineArgs prompt: str image_data: list[Image] - stop_token_ids: Optional[list[int]] = None - chat_template: Optional[str] = None - lora_requests: Optional[list[LoRARequest]] = None + stop_token_ids: list[int] | None = None + chat_template: str | None = None + lora_requests: list[LoRARequest] | None = None + sampling_params: SamplingParams | None = None # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on @@ -107,6 +108,41 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_bee(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "Open-Bee/Bee-8B-RL" + + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + trust_remote_code=True, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "CohereLabs/command-a-vision-07-2025" @@ -166,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData: + from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor + + model_name = "deepseek-ai/DeepSeek-OCR" + + engine_args = EngineArgs( + model=model_name, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + logits_processors=[NGramPerReqLogitsProcessor], + ) + + placeholder = "<image>\n" * len(image_urls) + prompt = placeholder + question + + # The following sampling params config is taken from + # the official Deepseek-OCR inference example. + # (IMPORTANT) Use the custom logits processor and avoid skipping + # special tokens for this model for the optimal OCR performance. 
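The whitelist IDs hard-coded just below (128821, 128822) are documented by the inline comment as the `<td>` and `</td>` tokens. If those values ever need to be re-derived rather than copied, a hedged sketch using the model's own tokenizer (the token strings and expected IDs come from that comment; that the lookup resolves to single vocabulary entries is an assumption):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-OCR", trust_remote_code=True
)
# Expected to print [128821, 128822] per the comment in this example.
print(tok.convert_tokens_to_ids(["<td>", "</td>"]))
```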
+ sampling_params = SamplingParams( + temperature=0.0, + max_tokens=8192, + # ngram logit processor args + extra_args=dict( + ngram_size=30, + window_size=90, + # whitelist: <td>, </td> + whitelist_token_ids={128821, 128822}, + ), + skip_special_tokens=False, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + sampling_params=sampling_params, + ) + + def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "google/gemma-3-4b-it" @@ -1215,8 +1291,10 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData: model_example_map = { "aria": load_aria, "aya_vision": load_aya_vision, + "bee": load_bee, "command_a_vision": load_command_a_vision, "deepseek_vl_v2": load_deepseek_vl2, + "deepseek_ocr": load_deepseek_ocr, "gemma3": load_gemma3, "h2ovl_chat": load_h2ovl, "hyperclovax_seed_vision": load_hyperclovax_seed_vision, @@ -1251,7 +1329,7 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData: } -def run_generate(model, question: str, image_urls: list[str], seed: Optional[int]): +def run_generate(model, question: str, image_urls: list[str], seed: int | None): req_data = model_example_map[model](question, image_urls) engine_args = asdict(req_data.engine_args) | {"seed": args.seed} @@ -1277,7 +1355,7 @@ def run_generate(model, question: str, image_urls: list[str], seed: Optional[int print("-" * 50) -def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[int]): +def run_chat(model: str, question: str, image_urls: list[str], seed: int | None): req_data = model_example_map[model](question, image_urls) # Disable other modalities to save memory @@ -1289,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[in engine_args = asdict(req_data.engine_args) | {"seed": seed} llm = LLM(**engine_args) - sampling_params = SamplingParams( - temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids + sampling_params = ( + SamplingParams( + temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids + ) + if req_data.sampling_params is None + else req_data.sampling_params ) outputs = llm.chat( [ diff --git a/examples/offline_inference/vision_language_pooling.py b/examples/offline_inference/vision_language_pooling.py index 33ffb59014d8..cf4695c2545f 100644 --- a/examples/offline_inference/vision_language_pooling.py +++ b/examples/offline_inference/vision_language_pooling.py @@ -11,7 +11,7 @@ from argparse import Namespace from dataclasses import asdict from pathlib import Path -from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args +from typing import Literal, NamedTuple, TypeAlias, TypedDict, get_args from PIL.Image import Image @@ -47,15 +47,15 @@ class TextImagesQuery(TypedDict): QueryModality = Literal["text", "image", "text+image", "text+images"] -Query = Union[TextQuery, ImageQuery, TextImageQuery, TextImagesQuery] +Query: TypeAlias = TextQuery | ImageQuery | TextImageQuery | TextImagesQuery class ModelRequestData(NamedTuple): engine_args: EngineArgs - prompt: Optional[str] = None - image: Optional[Image] = None - query: Optional[str] = None - documents: Optional[ScoreMultiModalParam] = None + prompt: str | None = None + image: Image | None = None + query: str | None = None + documents: ScoreMultiModalParam | None = None def run_clip(query: Query) -> ModelRequestData: @@ -110,6 +110,53 @@ def run_e5_v(query: Query) -> ModelRequestData: ) +def 
run_jinavl_reranker(query: Query) -> ModelRequestData: + if query["modality"] != "text+images": + raise ValueError(f"Unsupported query modality: '{query['modality']}'") + + engine_args = EngineArgs( + model="jinaai/jina-reranker-m0", + runner="pooling", + max_model_len=32768, + trust_remote_code=True, + mm_processor_kwargs={ + "min_pixels": 3136, + "max_pixels": 602112, + }, + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + query=query["text"], + documents=query["image"], + ) + + +def run_siglip(query: Query) -> ModelRequestData: + if query["modality"] == "text": + prompt = query["text"] + image = None + elif query["modality"] == "image": + prompt = "" # For image input, make sure that the prompt text is empty + image = query["image"] + else: + modality = query["modality"] + raise ValueError(f"Unsupported query modality: '{modality}'") + + engine_args = EngineArgs( + model="google/siglip-base-patch16-224", + runner="pooling", + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image=image, + ) + + def _get_vlm2vec_prompt_image(query: Query, image_token: str): if query["modality"] == "text": text = query["text"] @@ -211,29 +258,6 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: ) -def run_jinavl_reranker(query: Query) -> ModelRequestData: - if query["modality"] != "text+images": - raise ValueError(f"Unsupported query modality: '{query['modality']}'") - - engine_args = EngineArgs( - model="jinaai/jina-reranker-m0", - runner="pooling", - max_model_len=32768, - trust_remote_code=True, - mm_processor_kwargs={ - "min_pixels": 3136, - "max_pixels": 602112, - }, - limit_mm_per_prompt={"image": 1}, - ) - - return ModelRequestData( - engine_args=engine_args, - query=query["text"], - documents=query["image"], - ) - - def get_query(modality: QueryModality): if modality == "text": return TextQuery(modality="text", text="A dog sitting in the grass") @@ -281,7 +305,7 @@ def get_query(modality: QueryModality): raise ValueError(msg) -def run_encode(model: str, modality: QueryModality, seed: Optional[int]): +def run_encode(model: str, modality: QueryModality, seed: int | None): query = get_query(modality) req_data = model_example_map[model](query) @@ -311,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): print("-" * 50) -def run_score(model: str, modality: QueryModality, seed: Optional[int]): +def run_score(model: str, modality: QueryModality, seed: int | None): query = get_query(modality) req_data = model_example_map[model](query) @@ -328,9 +352,10 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]): model_example_map = { "clip": run_clip, "e5_v": run_e5_v, + "jinavl_reranker": run_jinavl_reranker, + "siglip": run_siglip, "vlm2vec_phi3v": run_vlm2vec_phi3v, "vlm2vec_qwen2vl": run_vlm2vec_qwen2vl, - "jinavl_reranker": run_jinavl_reranker, } diff --git a/examples/online_serving/dashboards/perses/performance_statistics.yaml b/examples/online_serving/dashboards/perses/performance_statistics.yaml index 2e8d24c3324b..8030fe2f00a9 100644 --- a/examples/online_serving/dashboards/perses/performance_statistics.yaml +++ b/examples/online_serving/dashboards/perses/performance_statistics.yaml @@ -530,7 +530,7 @@ spec: name: accelerators-thanos-querier-datasource # Multiply by 100 so we can read it as a percentage without setting a unit (avoids CUE unit conflicts) query: > - 100 * avg(vllm:gpu_cache_usage_perc) + 100 * 
avg(vllm:kv_cache_usage_perc) "18": kind: Panel diff --git a/examples/online_serving/dashboards/perses/query_statistics.yaml b/examples/online_serving/dashboards/perses/query_statistics.yaml index 28109aae8151..ad8e047f6dfe 100644 --- a/examples/online_serving/dashboards/perses/query_statistics.yaml +++ b/examples/online_serving/dashboards/perses/query_statistics.yaml @@ -98,7 +98,7 @@ spec: kind: PrometheusTimeSeriesQuery spec: datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } - query: avg(vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) or vector(0) + query: avg(vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) or vector(0) minStep: "15s" core_running_ts: @@ -168,7 +168,7 @@ spec: spec: datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } # multiply by 100 to present percentage; omit format.unit to avoid schema conflicts - query: (avg(vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + query: (avg(vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) minStep: "15s" core_kv_usage_pct_ts: @@ -187,7 +187,7 @@ spec: kind: PrometheusTimeSeriesQuery spec: datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } - query: (avg by (service) (vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + query: (avg by (service) (vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) minStep: "15s" # --- Per-Pod breakdowns (works on Simulator & Real) --- @@ -246,7 +246,7 @@ spec: spec: datasource: { kind: PrometheusDatasource, name: accelerators-thanos-querier-datasource } # if your exporter labels kv metric with pod (the sim does), this works; otherwise it will just return empty - query: (avg by (pod) (vllm:gpu_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) + query: (avg by (pod) (vllm:kv_cache_usage_perc{namespace="$NS",service="$SVC"}) * 100) or vector(0) minStep: "15s" # --- Real vLLM only (zeros on simulator) --- diff --git a/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py index 1df11d9d8495..2b8482ec717a 100644 --- a/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py +++ b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py @@ -23,7 +23,7 @@ import os import sys from abc import ABC, abstractmethod -from typing import Callable, Optional +from collections.abc import Callable import aiohttp import requests @@ -49,12 +49,9 @@ def __init__( decode_instances: list[str], model: str, scheduling_policy: SchedulingPolicy, - custom_create_completion: Optional[ - Callable[[Request], StreamingResponse] - ] = None, - custom_create_chat_completion: Optional[ - Callable[[Request], StreamingResponse] - ] = None, + custom_create_completion: Callable[[Request], StreamingResponse] | None = None, + custom_create_chat_completion: Callable[[Request], StreamingResponse] + | None = None, ): self.prefill_instances = prefill_instances self.decode_instances = decode_instances @@ -348,9 +345,9 @@ class ProxyServer: def __init__( self, args: argparse.Namespace, - scheduling_policy: Optional[SchedulingPolicy] = None, - create_completion: Optional[Callable[[Request], StreamingResponse]] = None, - create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None, + scheduling_policy: SchedulingPolicy | None = None, + create_completion: Callable[[Request], 
StreamingResponse] | None = None, + create_chat_completion: Callable[[Request], StreamingResponse] | None = None, ): self.validate_parsed_serve_args(args) self.port = args.port diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index f4b79b5e1302..19f6bd572610 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union +from typing import Any import msgspec import zmq @@ -25,16 +25,16 @@ class KVCacheEvent( class BlockStored(KVCacheEvent): block_hashes: list[ExternalBlockHash] - parent_block_hash: Optional[ExternalBlockHash] + parent_block_hash: ExternalBlockHash | None token_ids: list[int] block_size: int - lora_id: Optional[int] - medium: Optional[str] + lora_id: int | None + medium: str | None class BlockRemoved(KVCacheEvent): block_hashes: list[ExternalBlockHash] - medium: Optional[str] + medium: str | None class AllBlocksCleared(KVCacheEvent): @@ -42,7 +42,7 @@ class AllBlocksCleared(KVCacheEvent): class KVEventBatch(EventBatch): - events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]] + events: list[BlockStored | BlockRemoved | AllBlocksCleared] def process_event(event_batch): diff --git a/examples/online_serving/multi_instance_data_parallel.py b/examples/online_serving/multi_instance_data_parallel.py index cb230913a422..04d21e048940 100644 --- a/examples/online_serving/multi_instance_data_parallel.py +++ b/examples/online_serving/multi_instance_data_parallel.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -from typing import Optional +import threading from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.v1.metrics.loggers import AggregatedLoggingStatLogger """ To run this example, run the following commands simultaneously with @@ -22,37 +23,64 @@ """ +def _do_background_logging(engine, interval, stop_event): + try: + while not stop_event.is_set(): + asyncio.run(engine.do_log_stats()) + stop_event.wait(interval) + except Exception as e: + print(f"vLLM background logging shutdown: {e}") + pass + + async def main(): engine_args = AsyncEngineArgs( model="ibm-research/PowerMoE-3b", data_parallel_size=2, + tensor_parallel_size=1, dtype="auto", max_model_len=2048, data_parallel_address="127.0.0.1", data_parallel_rpc_port=62300, data_parallel_size_local=1, enforce_eager=True, + enable_log_requests=True, + disable_custom_all_reduce=True, ) - engine_client = AsyncLLMEngine.from_engine_args(engine_args) - + engine_client = AsyncLLMEngine.from_engine_args( + engine_args, + # Example: Using aggregated logger + stat_loggers=[AggregatedLoggingStatLogger], + ) + stop_logging_event = threading.Event() + logging_thread = threading.Thread( + target=_do_background_logging, + args=(engine_client, 5, stop_logging_event), + daemon=True, + ) + logging_thread.start() sampling_params = SamplingParams( temperature=0.7, top_p=0.9, max_tokens=100, ) + num_prompts = 10 + for i in range(num_prompts): + prompt = "Who won the 2004 World Series?" 
+ final_output: RequestOutput | None = None + async for output in engine_client.generate( + prompt=prompt, + sampling_params=sampling_params, + request_id=f"abcdef-{i}", + data_parallel_rank=1, + ): + final_output = output + if final_output: + print(final_output.outputs[0].text) - prompt = "Who won the 2004 World Series?" - final_output: Optional[RequestOutput] = None - async for output in engine_client.generate( - prompt=prompt, - sampling_params=sampling_params, - request_id="abcdef", - data_parallel_rank=1, - ): - final_output = output - if final_output: - print(final_output.outputs[0].text) + stop_logging_event.set() + logging_thread.join() if __name__ == "__main__": diff --git a/examples/online_serving/pooling/README.md b/examples/online_serving/pooling/README.md index 2c271b6a32bc..3b6da20d5f0f 100644 --- a/examples/online_serving/pooling/README.md +++ b/examples/online_serving/pooling/README.md @@ -6,16 +6,34 @@ python examples/online_serving/pooling/cohere_rerank_client.py ``` +## Embedding requests base64 encoding_format usage + +```bash +python examples/online_serving/pooling/embedding_requests_base64_client.py +``` + +## Embedding requests bytes encoding_format usage + +```bash +python examples/online_serving/pooling/embedding_requests_bytes_client.py +``` + ## Jinaai rerank usage ```bash python examples/online_serving/pooling/jinaai_rerank_client.py ``` +## Multi vector retrieval usage + +```bash +python examples/online_serving/pooling/multi_vector_retrieval_client.py +``` + ## Named Entity Recognition (NER) usage ```bash -python examples/online_serving/pooling/ner.py +python examples/online_serving/pooling/ner_client.py ``` ## Openai chat embedding for multimodal usage diff --git a/examples/online_serving/pooling/cohere_rerank_client.py b/examples/online_serving/pooling/cohere_rerank_client.py index 63c9ff9e9398..b32209967be9 100644 --- a/examples/online_serving/pooling/cohere_rerank_client.py +++ b/examples/online_serving/pooling/cohere_rerank_client.py @@ -8,8 +8,6 @@ run: vllm serve BAAI/bge-reranker-base """ -from typing import Union - import cohere from cohere import Client, ClientV2 @@ -25,7 +23,7 @@ def cohere_rerank( - client: Union[Client, ClientV2], model: str, query: str, documents: list[str] + client: Client | ClientV2, model: str, query: str, documents: list[str] ) -> dict: return client.rerank(model=model, query=query, documents=documents) diff --git a/examples/online_serving/pooling/embedding_requests_base64_client.py b/examples/online_serving/pooling/embedding_requests_base64_client.py new file mode 100644 index 000000000000..4c2399b58c11 --- /dev/null +++ b/examples/online_serving/pooling/embedding_requests_base64_client.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Example Python client for embedding API using vLLM API server +NOTE: + start a supported embeddings model server with `vllm serve`, e.g. 
+ vllm serve intfloat/e5-small +""" + +import argparse +import base64 + +import requests +import torch + +from vllm.utils.serial_utils import ( + EMBED_DTYPE_TO_TORCH_DTYPE, + ENDIANNESS, + binary2tensor, +) + + +def post_http_request(prompt: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="intfloat/e5-small") + + return parser.parse_args() + + +def main(args): + api_url = f"http://{args.host}:{args.port}/v1/embeddings" + model_name = args.model + + # The OpenAI client does not support the embed_dtype and endianness parameters. + for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: + for endianness in ENDIANNESS: + prompt = { + "model": model_name, + "input": "vLLM is great!", + "encoding_format": "base64", + "embed_dtype": embed_dtype, + "endianness": endianness, + } + response = post_http_request(prompt=prompt, api_url=api_url) + + embedding = [] + for data in response.json()["data"]: + binary = base64.b64decode(data["embedding"]) + tensor = binary2tensor(binary, (-1,), embed_dtype, endianness) + embedding.append(tensor.to(torch.float32)) + embedding = torch.cat(embedding) + print(embed_dtype, endianness, embedding.shape) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/pooling/embedding_requests_bytes_client.py b/examples/online_serving/pooling/embedding_requests_bytes_client.py new file mode 100644 index 000000000000..c2832f1b54ce --- /dev/null +++ b/examples/online_serving/pooling/embedding_requests_bytes_client.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Example Python client for embedding API using vLLM API server +NOTE: + start a supported embeddings model server with `vllm serve`, e.g. + vllm serve intfloat/e5-small +""" + +import argparse +import json + +import requests +import torch + +from vllm.utils.serial_utils import ( + EMBED_DTYPE_TO_TORCH_DTYPE, + ENDIANNESS, + MetadataItem, + decode_pooling_output, +) + + +def post_http_request(prompt: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="intfloat/e5-small") + + return parser.parse_args() + + +def main(args): + api_url = f"http://{args.host}:{args.port}/v1/embeddings" + model_name = args.model + + # The OpenAI client does not support the bytes encoding_format. + # The OpenAI client does not support the embed_dtype and endianness parameters. 
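The base64 and bytes clients added here go through raw HTTP because, as their comments note, the OpenAI SDK does not expose the extended `embed_dtype` and `endianness` parameters. For the plain `encoding_format="base64"` case with no extra parameters, the regular OpenAI client plus NumPy is enough; this sketch assumes the conventional little-endian float32 payload that base64 embeddings normally use, and reuses the server address and model from the clients above:

```python
import base64

import numpy as np
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="intfloat/e5-small",
    input="vLLM is great!",
    encoding_format="base64",
)
# Assumed layout: little-endian float32, the usual base64 embedding encoding.
vec = np.frombuffer(base64.b64decode(resp.data[0].embedding), dtype="<f4")
print(vec.shape)
```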
+ for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: + for endianness in ENDIANNESS: + prompt = { + "model": model_name, + "input": "vLLM is great!", + "encoding_format": "bytes", + "embed_dtype": embed_dtype, + "endianness": endianness, + } + response = post_http_request(prompt=prompt, api_url=api_url) + metadata = json.loads(response.headers["metadata"]) + body = response.content + items = [MetadataItem(**x) for x in metadata["data"]] + + embedding = decode_pooling_output(items=items, body=body) + embedding = [x.to(torch.float32) for x in embedding] + embedding = torch.cat(embedding) + print(embed_dtype, endianness, embedding.shape) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/pooling/multi_vector_retrieval_client.py b/examples/online_serving/pooling/multi_vector_retrieval_client.py new file mode 100644 index 000000000000..ef8c4745aa53 --- /dev/null +++ b/examples/online_serving/pooling/multi_vector_retrieval_client.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Example online usage of Pooling API for multi vector retrieval. + +Run `vllm serve <model> --runner pooling` +to start up the server in vLLM. e.g. + +vllm serve BAAI/bge-m3 +""" + +import argparse + +import requests +import torch + + +def post_http_request(prompt: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="BAAI/bge-m3") + + return parser.parse_args() + + +def main(args): + api_url = f"http://{args.host}:{args.port}/pooling" + model_name = args.model + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + prompt = {"model": model_name, "input": prompts} + + pooling_response = post_http_request(prompt=prompt, api_url=api_url) + for output in pooling_response.json()["data"]: + multi_vector = torch.tensor(output["data"]) + print(multi_vector.shape) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/pooling/ner.py b/examples/online_serving/pooling/ner_client.py similarity index 100% rename from examples/online_serving/pooling/ner.py rename to examples/online_serving/pooling/ner_client.py diff --git a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py index 16ac4378c686..261b810ce5d0 100644 --- a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py @@ -9,7 +9,7 @@ import argparse import base64 import io -from typing import Literal, Union +from typing import Literal from openai import OpenAI from openai._types import NOT_GIVEN, NotGiven @@ -29,7 +29,7 @@ def create_chat_embeddings( *, messages: list[ChatCompletionMessageParam], model: str, - encoding_format: Union[Literal["base64", "float"], NotGiven] = NOT_GIVEN, + encoding_format: Literal["base64", "float"] | NotGiven = NOT_GIVEN, ) -> CreateEmbeddingResponse: """ Convenience function for accessing vLLM's Chat Embeddings API, @@ -83,25 
+83,29 @@ def run_clip(client: OpenAI, model: str): print("Text embedding output:", response.data[0].embedding) -def run_vlm2vec(client: OpenAI, model: str): +def run_dse_qwen2_vl(client: OpenAI, model: str): """ Start the server using: - vllm serve TIGER-Lab/VLM2Vec-Full \ + vllm serve MrLight/dse-qwen2-2b-mrl-v1 \ --runner pooling \ --trust-remote-code \ - --max-model-len 4096 \ - --chat-template examples/template_vlm2vec_phi3v.jinja + --max-model-len 8192 \ + --chat-template examples/template_dse_qwen2_vl.jinja """ - response = create_chat_embeddings( client, messages=[ { "role": "user", "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": "Represent the given image."}, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + {"type": "text", "text": "What is shown in this image?"}, ], } ], @@ -111,17 +115,26 @@ def run_vlm2vec(client: OpenAI, model: str): print("Image embedding output:", response.data[0].embedding) + # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image + # of the minimum input size + buffer = io.BytesIO() + image_placeholder = Image.new("RGB", (56, 56)) + image_placeholder.save(buffer, "png") + buffer.seek(0) + image_placeholder = base64.b64encode(buffer.read()).decode("utf-8") response = create_chat_embeddings( client, messages=[ { "role": "user", "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, { - "type": "text", - "text": "Represent the given image with the following question: What is in the image.", + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_placeholder}", + }, }, + {"type": "text", "text": "Query: What is the weather like today?"}, ], } ], @@ -129,7 +142,16 @@ def run_vlm2vec(client: OpenAI, model: str): encoding_format="float", ) - print("Image+Text embedding output:", response.data[0].embedding) + print("Text embedding output:", response.data[0].embedding) + + +def run_siglip(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve google/siglip-base-patch16-224 \ + --runner pooling + """ response = create_chat_embeddings( client, @@ -137,7 +159,23 @@ def run_vlm2vec(client: OpenAI, model: str): { "role": "user", "content": [ - {"type": "text", "text": "A cat and a dog"}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image embedding output:", response.data[0].embedding) + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "a photo of a cat"}, ], } ], @@ -148,29 +186,25 @@ def run_vlm2vec(client: OpenAI, model: str): print("Text embedding output:", response.data[0].embedding) -def run_dse_qwen2_vl(client: OpenAI, model: str): +def run_vlm2vec(client: OpenAI, model: str): """ Start the server using: - vllm serve MrLight/dse-qwen2-2b-mrl-v1 \ + vllm serve TIGER-Lab/VLM2Vec-Full \ --runner pooling \ --trust-remote-code \ - --max-model-len 8192 \ - --chat-template examples/template_dse_qwen2_vl.jinja + --max-model-len 4096 \ + --chat-template examples/template_vlm2vec_phi3v.jinja """ + response = create_chat_embeddings( client, messages=[ { "role": "user", "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url, - }, - }, - {"type": "text", "text": "What is shown in this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Represent the given image."}, ], } ], @@ -180,26 +214,33 @@ def 
run_dse_qwen2_vl(client: OpenAI, model: str): print("Image embedding output:", response.data[0].embedding) - # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image - # of the minimum input size - buffer = io.BytesIO() - image_placeholder = Image.new("RGB", (56, 56)) - image_placeholder.save(buffer, "png") - buffer.seek(0) - image_placeholder = base64.b64encode(buffer.read()).decode("utf-8") response = create_chat_embeddings( client, messages=[ { "role": "user", "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_placeholder}", - }, + "type": "text", + "text": "Represent the given image with the following question: What is in the image.", }, - {"type": "text", "text": "Query: What is the weather like today?"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image+Text embedding output:", response.data[0].embedding) + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "A cat and a dog"}, ], } ], @@ -212,8 +253,9 @@ def run_dse_qwen2_vl(client: OpenAI, model: str): model_example_map = { "clip": run_clip, - "vlm2vec": run_vlm2vec, "dse_qwen2_vl": run_dse_qwen2_vl, + "siglip": run_siglip, + "vlm2vec": run_vlm2vec, } diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py index 611a7cbc89fa..a6246999c14d 100644 --- a/examples/online_serving/prithvi_geospatial_mae.py +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -11,14 +11,15 @@ # image as input, process it using the multimodal data processor, and # perform inference. # Requirements : -# - install plugin at: -# https://github.com/christian-pinto/prithvi_io_processor_plugin +# - install TerraTorch v1.1 (or later): +# pip install terratorch>=v1.1 # - start vllm in serving mode with the below args # --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' # --model-impl terratorch # --task embed --trust-remote-code # --skip-tokenizer-init --enforce-eager -# --io-processor-plugin prithvi_to_tiff +# --io-processor-plugin terratorch_segmentation +# --enable-mm-embeds def main(): @@ -34,7 +35,6 @@ def main(): }, "priority": 0, "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", - "softmax": False, } ret = requests.post(server_endpoint, json=request_payload_url) diff --git a/examples/online_serving/prometheus_grafana/grafana.json b/examples/online_serving/prometheus_grafana/grafana.json index 37abc9de926f..1c89d4593830 100644 --- a/examples/online_serving/prometheus_grafana/grafana.json +++ b/examples/online_serving/prometheus_grafana/grafana.json @@ -852,7 +852,7 @@ "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "vllm:gpu_cache_usage_perc{model_name=\"$model_name\"}", + "expr": "vllm:kv_cache_usage_perc{model_name=\"$model_name\"}", "instant": false, "legendFormat": "GPU Cache Usage", "range": true, diff --git a/examples/online_serving/structured_outputs/README.md b/examples/online_serving/structured_outputs/README.md index d2777a43d478..7f539716ecf8 100644 --- a/examples/online_serving/structured_outputs/README.md +++ b/examples/online_serving/structured_outputs/README.md @@ -21,7 +21,7 @@ If you want to run this script standalone with `uv`, you can use the following: ```bash uvx --from git+https://github.com/vllm-project/vllm#subdirectory=examples/online_serving/structured_outputs \ - structured-output + structured-outputs ``` See [feature 
docs](https://docs.vllm.ai/en/latest/features/structured_outputs.html) for more information. diff --git a/examples/online_serving/structured_outputs/pyproject.toml b/examples/online_serving/structured_outputs/pyproject.toml index 8f31405ff584..5e366ab0a03d 100644 --- a/examples/online_serving/structured_outputs/pyproject.toml +++ b/examples/online_serving/structured_outputs/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "examples-online-structured-outputs" -requires-python = ">=3.9, <3.13" +requires-python = ">=3.10, <3.14" dependencies = ["openai==1.78.1", "pydantic==2.11.4"] version = "0.0.0" diff --git a/examples/online_serving/structured_outputs/structured_outputs.py b/examples/online_serving/structured_outputs/structured_outputs.py index 3ea6c73e90e8..02853a95469a 100644 --- a/examples/online_serving/structured_outputs/structured_outputs.py +++ b/examples/online_serving/structured_outputs/structured_outputs.py @@ -1,21 +1,15 @@ # ruff: noqa: E501 # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from __future__ import annotations - import argparse import asyncio import enum import os -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal import openai import pydantic - -if TYPE_CHECKING: - from openai.types.chat import ChatCompletionChunk - +from openai.types.chat import ChatCompletionChunk ConstraintsFormat = Literal[ "choice", diff --git a/examples/others/tensorize_vllm_model.py b/examples/others/tensorize_vllm_model.py index acbfd8cda489..2601c9eff971 100644 --- a/examples/others/tensorize_vllm_model.py +++ b/examples/others/tensorize_vllm_model.py @@ -84,7 +84,7 @@ from vllm import LLM llm = LLM( "s3://my-bucket/vllm/facebook/opt-125m/v1", - load_format="tensorizer" + load_format="tensorizer", ) ``` diff --git a/pyproject.toml b/pyproject.toml index 704f28fa6536..29ee7f75f070 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.8.0", + "torch == 2.9.0", "wheel", "jinja2", ] @@ -20,7 +20,6 @@ license-files = ["LICENSE"] readme = "README.md" description = "A high-throughput and memory-efficient inference and serving engine for LLMs" classifiers = [ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -31,7 +30,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Information Analysis", ] -requires-python = ">=3.9,<3.14" +requires-python = ">=3.10,<3.14" dynamic = [ "version", "dependencies", "optional-dependencies"] [project.urls] @@ -79,12 +78,12 @@ ignore = [ "F405", "F403", # lambda expression assignment "E731", + # zip without `strict=` + "B905", # Loop control variable not used within loop body "B007", # f-string format "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", ] [tool.ruff.format] @@ -184,6 +183,7 @@ ba = "ba" [tool.typos.type.py.extend-words] ba = "ba" +nd = "nd" [tool.typos.type.cpp] extend-glob = ["*.cu"] diff --git a/requirements/build.txt b/requirements/build.txt index 5f826a1afa14..ba09eaab70e8 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -4,7 +4,7 @@ ninja packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -torch==2.8.0 +torch==2.9.0 wheel jinja2>=3.1.6 regex diff --git a/requirements/common.txt b/requirements/common.txt index 
1530e5a09e75..81c4d6675006 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -7,13 +7,13 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.55.2 +transformers >= 4.56.0 tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp openai >= 1.99.1 # For Responses API with reasoning content -pydantic >= 2.11.7 +pydantic >= 2.12.0 prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 @@ -31,15 +31,14 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 -importlib_metadata; python_version < '3.10' -mistral_common[image,audio] >= 1.8.2 +mistral_common[image,audio] >= 1.8.5 opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. -compressed-tensors == 0.11.0 # required for compressed-tensors -depyf==0.19.0 # required for profiling and debugging with compilation config +compressed-tensors == 0.12.2 # required for compressed-tensors +depyf==0.20.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/others/logging_configuration.md @@ -49,4 +48,4 @@ pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss -gpt-oss >= 0.0.7 +anthropic == 0.71.0 diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt index b511b0f5d31b..bba7bc7a4d8c 100644 --- a/requirements/cpu-build.txt +++ b/requirements/cpu-build.txt @@ -6,6 +6,7 @@ setuptools-scm>=8 --extra-index-url https://download.pytorch.org/whl/cpu torch==2.8.0+cpu; platform_machine == "x86_64" torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin" +scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL) wheel jinja2>=3.1.6 regex diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 2db6d87ee67b..d53ab3649308 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -1,8 +1,7 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' and platform_machine != "s390x" # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' and platform_machine != "s390x" +numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding # Dependencies for CPUs packaging>=24.2 diff --git a/requirements/cuda.txt b/requirements/cuda.txt index 3f8b8fca3209..7c5bc457d45b 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -1,14 +1,17 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.61.2 # Required for N-gram speculative decoding # Dependencies for NVIDIA GPUs ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.8.0 -torchaudio==2.8.0 +torch==2.9.0 +torchaudio==2.9.0 # These must be updated alongside torch -torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1 -xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 +# xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8 +# FlashInfer should be updated together with the Dockerfile +flashinfer-python==0.4.1 +# Triton Kernels are needed for mxfp4 fused moe. (Should be updated alongside torch) +triton_kernels @ git+https://github.com/triton-lang/triton.git@v3.5.0#subdirectory=python/triton_kernels diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 33f1bc04ea90..dea1926bbd69 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -23,7 +23,7 @@ jiwer # required for audio tests timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.8.2 # required for voxtral test +mistral_common[image,audio] >= 1.8.5 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test @@ -40,9 +40,8 @@ buildkite-test-collector==0.1.9 genai_perf==0.0.8 tritonclient==2.51.0 -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.61.2 # Required for N-gram speculative decoding numpy runai-model-streamer[s3,gcs]==0.14.0 fastsafetensors>=0.1.10 -pydantic>=2.10 # 2.9 leads to error on python 3.10 +pydantic>=2.12 # 2.11 leads to error on python 3.13 diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index a86a8ab6df14..51f58e57a785 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -1,12 +1,12 @@ # Common dependencies -r common.txt ---extra-index-url https://download.pytorch.org/whl/rocm6.3 -torch==2.8.0 -torchvision==0.23.0 -torchaudio==2.8.0 +--extra-index-url https://download.pytorch.org/whl/rocm6.4 +torch==2.9.0 +torchvision==0.24.0 +torchaudio==2.9.0 -triton==3.3.0 +triton==3.5.0 cmake>=3.26.1,<4 packaging>=24.2 setuptools>=77.0.3,<80.0.0 diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index 869fb28c3d85..541fa1e267cb 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -1,6 +1,8 @@ # Common dependencies -r common.txt tblib==3.1.0 +bm25s==0.2.13 +pystemmer==3.0.0 # entrypoints test # librosa==0.10.2.post1 # required by audio tests in entrypoints/openai @@ -29,4 +31,8 @@ matplotlib==3.10.3 # Multi-Modal Models Test (Extended) 3 blobfile==3.0.0 +# Required for openai schema test. 
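The requirement updates above move the CUDA stack to torch 2.9.0 with a pinned `flashinfer-python==0.4.1` and Triton 3.5.0. A quick environment sanity check can catch a stale install before running anything heavier; a minimal sketch using only the standard library, with package names and expected versions taken from the pins above:

```python
from importlib.metadata import version

# Package names and expected versions come from the requirement pins above.
for pkg, expected in [
    ("torch", "2.9.0"),
    ("flashinfer-python", "0.4.1"),
    ("triton", "3.5.0"),
]:
    installed = version(pkg)
    status = "ok" if installed.startswith(expected) else "MISMATCH"
    print(f"{pkg}: {installed} (expected {expected}) -> {status}")
```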
+schemathesis==3.39.15 +# required for mteb test +mteb[bm25s]>=1.38.11, <2 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 9077085f2621..d9743f044643 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -1,8 +1,7 @@ # Common dependencies -r common.txt -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.61.2 # Required for N-gram speculative decoding # Dependencies for AMD GPUs datasets diff --git a/requirements/test.in b/requirements/test.in index ef21d6db5b4f..a79ec839dbec 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -24,12 +24,12 @@ soundfile # required for audio tests jiwer # required for audio tests tblib # for pickling test exceptions timm >=1.0.17 # required for internvl and gemma3n-mm test -torch==2.8.0 -torchaudio==2.8.0 -torchvision==0.23.0 +torch==2.9.0 +torchaudio==2.9.0 +torchvision==0.24.0 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[image,audio] >= 1.8.2 # required for voxtral test +mistral_common[image,audio] >= 1.8.5 # required for voxtral test num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test @@ -48,11 +48,11 @@ buildkite-test-collector==0.1.9 genai_perf==0.0.8 tritonclient==2.51.0 -numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -numba == 0.61.2; python_version > '3.9' +numba == 0.61.2 # Required for N-gram speculative decoding numpy runai-model-streamer[s3,gcs]==0.14.0 fastsafetensors>=0.1.10 -pydantic>=2.10 # 2.9 leads to error on python 3.10 +pydantic>=2.12 # 2.11 leads to error on python 3.13 decord==0.6.0 terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test +gpt-oss >= 0.0.7; python_version > '3.11' diff --git a/requirements/test.txt b/requirements/test.txt index 9cab85ce0ef6..bc007ccf10bb 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 --python-platform x86_64-manylinux_2_28 +# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu129 --python-platform x86_64-manylinux_2_28 absl-py==2.1.0 # via rouge-score accelerate==1.0.1 @@ -10,18 +10,19 @@ aenum==3.1.16 # via lightly affine==2.4.0 # via rasterio -aiohappyeyeballs==2.4.3 +aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.10.11 +aiohttp==3.13.0 # via # aiohttp-cors # datasets # fsspec + # gpt-oss # lm-eval # ray aiohttp-cors==0.8.1 # via ray -aiosignal==1.3.1 +aiosignal==1.4.0 # via aiohttp albucore==0.0.16 # via terratorch @@ -103,6 +104,8 @@ chardet==5.2.0 # via mbstrdecoder charset-normalizer==3.4.0 # via requests +chz==0.3.0 + # via gpt-oss click==8.1.7 # via # black @@ -173,7 +176,9 @@ distlib==0.3.9 dnspython==2.7.0 # via email-validator docker==7.1.0 - # via mlflow + # via + # gpt-oss + # mlflow docopt==0.6.2 # via num2words docstring-parser==0.17.0 @@ -199,7 +204,9 @@ eval-type-backport==0.2.2 evaluate==0.4.3 # via lm-eval fastapi==0.116.1 - # via mlflow-skinny + # via + # gpt-oss + # mlflow-skinny fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -274,6 +281,8 @@ 
google-resumable-media==2.7.2 # via google-cloud-storage googleapis-common-protos==1.70.0 # via google-api-core +gpt-oss==0.0.8 + # via -r requirements/test.in graphene==3.4.3 # via mlflow graphql-core==3.2.6 @@ -301,6 +310,8 @@ hf-xet==1.1.7 # via huggingface-hub hiredis==3.0.0 # via tensorizer +html2text==2025.4.15 + # via gpt-oss httpcore==1.0.6 # via httpx httpx==0.27.2 @@ -435,6 +446,7 @@ lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b772215 lxml==5.3.0 # via # blobfile + # gpt-oss # sacrebleu mako==1.3.10 # via alembic @@ -462,7 +474,7 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.8.2 +mistral-common==1.8.5 # via -r requirements/test.in mlflow==2.22.0 # via terratorch @@ -561,42 +573,44 @@ numpy==1.26.4 # tritonclient # vocos # xarray -nvidia-cublas-cu12==12.8.4.1 +nvidia-cublas-cu12==12.9.1.4 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-cupti-cu12==12.9.79 # via torch -nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-nvrtc-cu12==12.9.86 # via torch -nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cuda-runtime-cu12==12.9.79 # via torch nvidia-cudnn-cu12==9.10.2.21 # via torch -nvidia-cufft-cu12==11.3.3.83 +nvidia-cufft-cu12==11.4.1.4 # via torch -nvidia-cufile-cu12==1.13.1.3 +nvidia-cufile-cu12==1.14.1.1 # via torch -nvidia-curand-cu12==10.3.9.90 +nvidia-curand-cu12==10.3.10.19 # via torch -nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusolver-cu12==11.7.5.82 # via torch -nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparse-cu12==12.5.10.65 # via # nvidia-cusolver-cu12 # torch nvidia-cusparselt-cu12==0.7.1 # via torch -nvidia-nccl-cu12==2.27.3 +nvidia-nccl-cu12==2.27.5 # via torch -nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvjitlink-cu12==12.9.86 # via # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.8.90 +nvidia-nvshmem-cu12==3.3.20 + # via torch +nvidia-nvtx-cu12==12.9.79 # via torch omegaconf==2.3.0 # via @@ -604,6 +618,8 @@ omegaconf==2.3.0 # lightning open-clip-torch==2.32.0 # via -r requirements/test.in +openai-harmony==0.0.4 + # via gpt-oss opencensus==0.11.4 # via ray opencensus-context==0.1.3 @@ -724,7 +740,9 @@ prometheus-client==0.22.0 # opentelemetry-exporter-prometheus # ray propcache==0.2.0 - # via yarl + # via + # aiohttp + # yarl proto-plus==1.26.1 # via google-api-core protobuf==5.28.3 @@ -767,19 +785,21 @@ pycparser==2.22 # via cffi pycryptodomex==3.22.0 # via blobfile -pydantic==2.11.7 +pydantic==2.12.0 # via # -r requirements/test.in # albumentations # datamodel-code-generator # fastapi + # gpt-oss # lightly # mistral-common # mlflow-skinny # mteb + # openai-harmony # pydantic-extra-types # ray -pydantic-core==2.33.2 +pydantic-core==2.41.1 # via pydantic pydantic-extra-types==2.10.5 # via mistral-common @@ -907,6 +927,7 @@ requests==2.32.3 # evaluate # google-api-core # google-cloud-storage + # gpt-oss # huggingface-hub # lightly # lm-eval @@ -993,14 +1014,11 @@ sentence-transformers==3.2.1 # via # -r requirements/test.in # mteb -sentencepiece==0.2.0 - # via mistral-common setuptools==77.0.3 # via # lightning-utilities # pytablewriter # torch - # triton shapely==2.1.1 # via # geopandas @@ -1052,6 +1070,8 @@ starlette-testclient==0.4.1 # via schemathesis statsmodels==0.14.4 # via genai-perf +structlog==25.4.0 + # via gpt-oss sympy==1.13.3 # via # einx @@ -1064,14 +1084,17 @@ tblib==3.1.0 # via -r requirements/test.in tcolorpy==0.1.6 # via pytablewriter -tenacity==9.0.0 +tenacity==9.1.2 # via + # gpt-oss # lm-eval # 
plotly tensorboardx==2.6.4 # via lightning tensorizer==2.10.1 # via -r requirements/test.in +termcolor==3.1.0 + # via gpt-oss terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e # via -r requirements/test.in threadpoolctl==3.5.0 @@ -1080,8 +1103,9 @@ tifffile==2025.3.30 # via # scikit-image # terratorch -tiktoken==0.7.0 +tiktoken==0.12.0 # via + # gpt-oss # lm-eval # mistral-common timm==1.0.17 @@ -1099,7 +1123,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.8.0+cu128 +torch==2.9.0+cu129 # via # -r requirements/test.in # accelerate @@ -1128,7 +1152,7 @@ torch==2.8.0+cu128 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.8.0+cu128 +torchaudio==2.9.0+cu129 # via # -r requirements/test.in # encodec @@ -1141,7 +1165,7 @@ torchmetrics==1.7.4 # pytorch-lightning # terratorch # torchgeo -torchvision==0.23.0+cu128 +torchvision==0.24.0+cu129 # via # -r requirements/test.in # lightly @@ -1182,7 +1206,7 @@ transformers==4.56.2 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.4.0 +triton==3.5.0 # via torch tritonclient==2.51.0 # via @@ -1199,10 +1223,12 @@ types-python-dateutil==2.9.0.20241206 # via arrow typeshed-client==2.8.2 # via jsonargparse -typing-extensions==4.12.2 +typing-extensions==4.15.0 # via + # aiosignal # albumentations # alembic + # chz # fastapi # graphene # huggingface-hub @@ -1226,7 +1252,7 @@ typing-extensions==4.12.2 # typer # typeshed-client # typing-inspection -typing-inspection==0.4.1 +typing-inspection==0.4.2 # via pydantic tzdata==2024.2 # via pandas @@ -1242,7 +1268,9 @@ urllib3==2.2.3 # responses # tritonclient uvicorn==0.35.0 - # via mlflow-skinny + # via + # gpt-oss + # mlflow-skinny vector-quantize-pytorch==1.21.2 # via -r requirements/test.in virtualenv==20.31.2 diff --git a/requirements/xpu.txt b/requirements/xpu.txt index 74f5b05b2382..d14b631aa936 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -9,8 +9,7 @@ setuptools>=77.0.3,<80.0.0 wheel jinja2>=3.1.6 datasets # for benchmark scripts -numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding -nixl==0.3.0 # for PD disaggregation +numba == 0.61.2 # Required for N-gram speculative decoding torch==2.8.0+xpu torchaudio torchvision diff --git a/setup.py b/setup.py index 53c460d2c5b8..990fe4cde3ca 100644 --- a/setup.py +++ b/setup.py @@ -540,6 +540,11 @@ def get_gaudi_sw_version(): def get_vllm_version() -> str: + # Allow overriding the version. This is useful to build platform-specific + # wheels (e.g. CPU, TPU) without modifying the source. + if env_version := os.getenv("VLLM_VERSION_OVERRIDE"): + return env_version + version = get_version(write_to="vllm/_version.py") sep = "+" if "+" not in version else "." 
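`get_vllm_version()` now honors a `VLLM_VERSION_OVERRIDE` environment variable so platform-specific wheels (e.g. CPU, TPU) can be stamped without modifying the source. A hedged sketch of driving that from a build script; the override value and the choice of `python -m build` are illustrative, not prescribed by the change:

```python
import os
import subprocess

# Hypothetical platform wheel build; only VLLM_VERSION_OVERRIDE comes from the
# setup.py change above, everything else is illustrative.
env = dict(os.environ, VLLM_VERSION_OVERRIDE="0.0.0+cpu")
subprocess.run(
    ["python", "-m", "build", "--wheel", "--no-isolation"],
    env=env,
    check=True,
)
```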
# dev versions might contain + @@ -714,8 +719,7 @@ def _read_requirements(filename: str) -> list[str]: "mistral_common[audio]", ], # Required for audio processing "video": [], # Kept for backwards compatibility - # FlashInfer should be updated together with the Dockerfile - "flashinfer": ["flashinfer-python==0.3.1"], + "flashinfer": [], # Kept for backwards compatibility # Optional deps for AMD FP4 quantization support "petit-kernel": ["petit-kernel"], }, diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 9b9d8cfea7fa..0cf1e85d4e8e 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -20,7 +20,7 @@ from ..utils import multi_gpu_test MODELS = [ - "google/gemma-2-2b-it", + "hmellor/tiny-random-Gemma2ForCausalLM", "meta-llama/Llama-3.2-1B-Instruct", ] @@ -29,7 +29,7 @@ def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" - llm = LLM("distilbert/distilgpt2") + llm = LLM("hmellor/tiny-random-LlamaForCausalLM") weak_llm = weakref.ref(llm) del llm # If there's any circular reference to vllm, this fails @@ -125,14 +125,14 @@ def test_models( @pytest.mark.parametrize( "model, distributed_executor_backend, attention_backend, test_suite, extra_env", [ - ("distilbert/distilgpt2", "ray", "", "L4", {}), - ("distilbert/distilgpt2", "mp", "", "L4", {}), - ("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), - ("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), + ("facebook/opt-125m", "ray", "", "L4", {}), + ("facebook/opt-125m", "mp", "", "L4", {}), + ("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), + ("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), - ("distilbert/distilgpt2", "ray", "", "A100", {}), - ("distilbert/distilgpt2", "mp", "", "A100", {}), + ("facebook/opt-125m", "ray", "", "A100", {}), + ("facebook/opt-125m", "mp", "", "A100", {}), ], ) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) @@ -157,11 +157,9 @@ def test_models_distributed( and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4" + and enable_prompt_embeds ): # noqa - if enable_prompt_embeds: - pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") - monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") - monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") + pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") if attention_backend: monkeypatch_context.setenv( diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index 3c1e01d072b9..89839372c309 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -6,5 +6,5 @@ def test_cpu_offload(): compare_two_settings( - "meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"] + "hmellor/tiny-random-LlamaForCausalLM", [], ["--cpu-offload-gb", "1"] ) diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py index f1b0f7b2de89..09f4ec03fbbb 100644 --- a/tests/basic_correctness/test_cumem.py +++ b/tests/basic_correctness/test_cumem.py @@ -6,7 +6,7 @@ from vllm import LLM, SamplingParams from vllm.device_allocator.cumem import CuMemAllocator -from vllm.utils import GiB_bytes +from 
vllm.utils.mem_constants import GiB_bytes from ..utils import create_new_process_for_each_test @@ -120,7 +120,7 @@ def model(x): "model", [ # sleep mode with safetensors - "meta-llama/Llama-3.2-1B", + "hmellor/tiny-random-LlamaForCausalLM", # sleep mode with pytorch checkpoint "facebook/opt-125m", ], @@ -174,7 +174,7 @@ def test_end_to_end(model: str): @create_new_process_for_each_test() def test_deep_sleep(): - model = "Qwen/Qwen3-0.6B" + model = "hmellor/tiny-random-LlamaForCausalLM" free, total = torch.cuda.mem_get_info() used_bytes_baseline = total - free # in case other process is running llm = LLM(model, enable_sleep_mode=True) diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py index 90527dbeae28..68e4afdcbe52 100644 --- a/tests/benchmarks/test_random_dataset.py +++ b/tests/benchmarks/test_random_dataset.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -from typing import Any, NamedTuple, Optional, cast +from typing import Any, NamedTuple, cast import numpy as np import pytest @@ -185,8 +185,8 @@ def _collect_mm_samples( output_len: int = 5, base_items_per_request: int = 2, num_mm_items_range_ratio: float = 0.0, - limit_mm_per_prompt: Optional[dict[str, int]] = None, - bucket_config: Optional[dict[tuple[int, int, int], float]] = None, + limit_mm_per_prompt: dict[str, int] | None = None, + bucket_config: dict[tuple[int, int, int], float] | None = None, enable_multimodal_chat: bool = False, ) -> list[SampleRequest]: if limit_mm_per_prompt is None: diff --git a/tests/ci_envs.py b/tests/ci_envs.py index d16ecce1ef8d..f3a54f308cd8 100644 --- a/tests/ci_envs.py +++ b/tests/ci_envs.py @@ -5,13 +5,16 @@ """ import os -from typing import TYPE_CHECKING, Any, Callable, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from vllm.envs import maybe_convert_bool if TYPE_CHECKING: VLLM_CI_NO_SKIP: bool = False - VLLM_CI_DTYPE: Optional[str] = None - VLLM_CI_HEAD_DTYPE: Optional[str] = None - VLLM_CI_HF_DTYPE: Optional[str] = None + VLLM_CI_DTYPE: str | None = None + VLLM_CI_HEAD_DTYPE: str | None = None + VLLM_CI_HF_DTYPE: str | None = None environment_variables: dict[str, Callable[[], Any]] = { # A model family has many models with the same architecture. 
@@ -24,6 +27,10 @@ "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None), # Allow changing the head dtype used by transformers in tests "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None), + # Allow control over whether tests use enforce_eager + "VLLM_CI_ENFORCE_EAGER": lambda: maybe_convert_bool( + os.getenv("VLLM_CI_ENFORCE_EAGER", None) + ), } diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 36bc832a1329..fa426190067f 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -2,18 +2,23 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import weakref -from collections.abc import Sequence +from collections.abc import Callable, Sequence +from contextlib import nullcontext from copy import deepcopy -from typing import Callable, Union +import depyf from torch import fx from torch._ops import OpOverload +from torch.fx._utils import lazy_format_graph_code from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.pass_manager import with_pattern_match_debug from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger("vllm.tests.compile.backend") class LazyInitPass(InductorPass): @@ -44,22 +49,34 @@ class TestBackend: Inductor config is default-initialized from VllmConfig.CompilationConfig. """ - def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): + def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]): self.custom_passes = list(passes) - compile_config = get_current_vllm_config().compilation_config - self.inductor_config = compile_config.inductor_compile_config + vllm_config = get_current_vllm_config() + compile_config = vllm_config.compilation_config + # Deepcopy to allow multiple TestBackend instances to use the same VllmConfig + self.inductor_config = deepcopy(compile_config.inductor_compile_config) self.inductor_config["force_disable_caches"] = True self.inductor_config["post_grad_custom_post_pass"] = self.post_pass + if debug_dump_path := vllm_config.compile_debug_dump_path(): + logger.debug("Dumping depyf output to %s", debug_dump_path) + self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix()) + else: + self.debug_ctx = nullcontext() + def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx - return compile_fx(graph, example_inputs, config_patches=self.inductor_config) + with self.debug_ctx: + return compile_fx( + graph, example_inputs, config_patches=self.inductor_config + ) @with_pattern_match_debug def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) + lazy_format_graph_code("graph_pre_pass", graph.owning_module) VllmInductorPass.dump_prefix = 0 for pass_ in self.custom_passes: @@ -69,6 +86,7 @@ def post_pass(self, graph: fx.Graph): VllmInductorPass.dump_prefix = None self.graph_post_pass = deepcopy(graph) + lazy_format_graph_code("graph_post_pass", graph.owning_module) # assign by reference, will reflect the final state of the graph self.final_graph = graph diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 84194f3ed01e..c6d4b5272dbc 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -11,6 +11,7 @@ from vllm import 
LLM, SamplingParams from vllm.config import CompilationConfig from vllm.platforms import current_platform +from vllm.utils.torch_utils import is_torch_equal_or_newer @contextlib.contextmanager @@ -32,13 +33,13 @@ def temporary_environ(env_vars): os.environ[k] = v -test_params_full_cudagraph = [] +model_backends_full_cudagraph = [] # deepseek-ai/DeepSeek-V2-Lite with MLA MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] for mla_backend in MLA_backends: - test_params_full_cudagraph.append( - pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])) + model_backends_full_cudagraph.append( + ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]) ) # Qwen/Qwen2-1.5B-Instruct with other backends @@ -46,14 +47,18 @@ def temporary_environ(env_vars): backend_configs[c] for c in backend_configs if c not in MLA_backends ] for backend_config in other_backend_configs: - test_params_full_cudagraph.append( - pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)) - ) + model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config)) @pytest.fixture(scope="class") def llm_pair(request): - model, backend_config = request.param + model, backend_config, use_inductor_graph_partition = request.param + backend_config.comp_config["use_inductor_graph_partition"] = ( + use_inductor_graph_partition + ) + + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition only supported in torch>=2.9") # Dynamically skip test if GPU capability is not met if ( @@ -104,7 +109,15 @@ def llm_pair(request): ) -@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True) +@pytest.mark.parametrize( + "llm_pair", + [ + pytest.param((model, backend_config, use_inductor_graph_partition)) + for model, backend_config in model_backends_full_cudagraph + for use_inductor_graph_partition in [True, False] + ], + indirect=True, +) class TestFullCUDAGraph: """ Use a class such that an llm pair is constructed once for all diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 7372dc99bc79..700f57ffb068 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -5,6 +5,7 @@ are compiled and graph captured separately. """ +import pytest import torch from torch import nn @@ -13,12 +14,13 @@ from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. 
import silly_attention # noqa: F401 @@ -190,16 +192,21 @@ def run_model( return output.cpu() -def test_multi_graph_piecewise_compile_outputs_equal(): +@pytest.mark.parametrize("use_inductor_graph_partition", [False, True]) +def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool): + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + outputs = [] - # piecewise compile + # vllmcompile compile vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, - splitting_ops=["silly.attention"], + splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], + use_inductor_graph_partition=use_inductor_graph_partition, ) ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE @@ -220,23 +227,31 @@ def test_multi_graph_piecewise_compile_outputs_equal(): # static tensor addresses inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() - with compilation_counter.expect( - num_graphs_seen=2, # two graphs for the model - num_piecewise_graphs_seen=6, + if use_inductor_graph_partition: + # Splitting happens at Inductor lowering level, + # total piecewise fx graphs is equal to total graphs + num_piecewise_fx = 2 + num_piecewise_capturable_fx = 2 + else: # attn_one, attn_two each has 3 piecewise graphs # (pre attn, post attn, silly_attention) each - num_piecewise_capturable_graphs_seen=4, + num_piecewise_fx = 6 # attn_one, attn_two has pre attn and post attn each, total=4 - num_backend_compilations=4, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_piecewise_capturable_fx = 4 + + with compilation_counter.expect( + num_graphs_seen=2, # two graphs for the model + num_piecewise_graphs_seen=num_piecewise_fx, + num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx, + num_backend_compilations=num_piecewise_capturable_fx, + num_cudagraph_captured=8, # num_cudagraph_sizes * num_partitions ): outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # no compile or cudagraph vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.NO_COMPILATION, + mode=CompilationMode.NONE, ) ) cudagraph_runtime_mode = CUDAGraphMode.NONE @@ -265,9 +280,10 @@ def test_multi_graph_piecewise_compile_outputs_equal(): # piecewise compile without CUDA graph vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=False, - splitting_ops=["silly.attention"], + splitting_ops=["silly::attention"], + use_inductor_graph_partition=use_inductor_graph_partition, ) ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE @@ -286,9 +302,9 @@ def test_multi_graph_piecewise_compile_outputs_equal(): with compilation_counter.expect( num_graphs_seen=2, - num_piecewise_graphs_seen=6, - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, + num_piecewise_graphs_seen=num_piecewise_fx, + num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx, + num_backend_compilations=num_piecewise_capturable_fx, num_cudagraph_captured=0, # no cudagraph captured ): outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 8241d248fa53..228859532ef4 100644 --- 
a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -13,13 +13,13 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from ..silly_attention import get_global_counter, reset_global_counter @@ -61,7 +61,7 @@ def _run_simple_model( ): vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, use_inductor=use_inductor, splitting_ops=splitting_ops, @@ -127,7 +127,7 @@ def _run_simple_model( @torch.inference_mode() def test_simple_piecewise_compile(use_inductor): _run_simple_model( - splitting_ops=["silly.attention"], + splitting_ops=["silly::attention"], use_inductor_graph_partition=False, use_inductor=use_inductor, # 2 * num_layers + 1 @@ -142,14 +142,16 @@ def test_simple_piecewise_compile(use_inductor): @torch.inference_mode() -@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []]) -def test_simple_inductor_graph_partition(splitting_ops): +def test_simple_inductor_graph_partition(monkeypatch): if not is_torch_equal_or_newer("2.9.0.dev"): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + # disable compile cache so that we run separately for different splitting_ops + # and get the expected number of cudagraphs captured. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + _run_simple_model( - # Inductor graph partition automatically resets splitting_ops to an empty list - splitting_ops=splitting_ops, + splitting_ops=["silly::attention"], use_inductor_graph_partition=True, use_inductor=True, # Since not splitting at fx graph level diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index c3aff8ddad49..6887673eb6a5 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -9,8 +9,9 @@ initialized randomly with a fixed seed. """ +from copy import deepcopy from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import pytest import torch @@ -20,12 +21,13 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. 
import silly_attention # noqa: F401 @@ -162,7 +164,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: """ For tractable computation: @@ -217,7 +219,7 @@ def __init__( def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, ) -> torch.Tensor: hidden_states = self.embedding_tokens(input_ids) @@ -257,27 +259,13 @@ def tractable_computation( @torch.inference_mode -def run_model( - llama_config, use_compile: bool, backend: str, split_attn: bool = False -) -> torch.Tensor: - if use_compile: - compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - backend=backend, - cudagraph_capture_sizes=[1, 2], - ) - if split_attn: - compilation_config.splitting_ops = ["silly.attention"] - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - compilation_config = CompilationConfig( - level=CompilationLevel.NO_COMPILATION, - ) - cudagraph_runtime_mode = CUDAGraphMode.NONE +def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor: + # Start with a fresh copy to make sure there's no cache dir sharing + compile_config = deepcopy(compile_config) + cudagraph_runtime_mode = compile_config.cudagraph_mode vllm_config = VllmConfig( - compilation_config=compilation_config, additional_config=llama_config + compilation_config=compile_config, additional_config=llama_config ) with set_current_vllm_config(vllm_config): model = ( @@ -338,8 +326,24 @@ def run_model( return output.cpu() -@pytest.mark.parametrize("backend", ["inductor", "eager"]) -def test_toy_llama(backend: str): +@pytest.mark.parametrize( + "backend, use_inductor_graph_partition", + [ + ("eager", False), # No inductor + ("inductor", False), # Inductor, Dynamo partition + ("inductor", True), # Inductor, Inductor partition + ], +) +def test_toy_llama( + backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path +): + # We disable the vLLM compile cache into a new tmp dir for 1 reason: + # 1. To make sure we can properly track the number of Inductor compilations. 
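The cache-disabling idiom used here (and again in test_decorator.py below) keeps compilation_counter honest; one way to factor it, sketched with pytest only (fixture name and scope are hypothetical, not part of this change):

    import pytest

    @pytest.fixture
    def fresh_compile_cache(monkeypatch: pytest.MonkeyPatch) -> None:
        # Force a fresh Inductor compile for the test so counters such as
        # compilation_counter.expect(...) observe the true number of backend
        # compilations instead of cache hits.
        monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")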
+ monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition only supported in torch>=2.9") + # compare output with and without piecewise compilation llama_config = LlamaConfig( @@ -350,6 +354,23 @@ def test_toy_llama(backend: str): hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True ) + compile_config_no_compile = CompilationConfig( + mode=CompilationMode.NONE, + cudagraph_mode=CUDAGraphMode.NONE, + backend="eager", + ) + + compile_config_no_split = CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + use_inductor_graph_partition=use_inductor_graph_partition, + cudagraph_mode=CUDAGraphMode.PIECEWISE, + backend=backend, + cudagraph_capture_sizes=[1, 2], + ) + + compile_config_split = deepcopy(compile_config_no_split) + compile_config_split.splitting_ops = ["silly::attention"] + outputs = [] with compilation_counter.expect( num_graphs_seen=0, @@ -358,8 +379,9 @@ def test_toy_llama(backend: str): num_backend_compilations=0, num_cudagraph_captured=0, ): - outputs.append(run_model(llama_config, backend="eager", use_compile=False)) - run_model(tractable_config, backend="eager", use_compile=False) + outputs.append(run_model(llama_config, compile_config_no_compile)) + + run_model(tractable_config, compile_config_no_compile) if backend == "inductor": kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0} @@ -367,35 +389,34 @@ def test_toy_llama(backend: str): kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0} with compilation_counter.expect( - # One graph for the model - num_graphs_seen=1, + num_graphs_seen=1, # one graph for the model num_piecewise_graphs_seen=1, num_piecewise_capturable_graphs_seen=1, - # num_piecewise_capturable_graphs_seen - num_backend_compilations=1, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_backend_compilations=1, # num_piecewise_capturable_graphs_seen num_cudagraph_captured=2, **kwargs, ): - outputs.append(run_model(llama_config, backend=backend, use_compile=True)) - run_model(tractable_config, backend=backend, use_compile=True) + outputs.append(run_model(llama_config, compile_config_no_split)) + + run_model(tractable_config, compile_config_no_split) + + if use_inductor_graph_partition: + num_piecewise_fx = 1 + num_piecewise_capturable_fx = 1 + else: + num_piecewise_fx = 2 * llama_config.num_layers + 1 + num_piecewise_capturable_fx = 1 + llama_config.num_layers with compilation_counter.expect( num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1 - num_piecewise_capturable_graphs_seen=1 - + llama_config.num_layers, # 1 + num_layers - num_backend_compilations=1 - + llama_config.num_layers, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured=2 - * ( - 1 + llama_config.num_layers - ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_piecewise_graphs_seen=num_piecewise_fx, + num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx, + num_backend_compilations=num_piecewise_capturable_fx, + # num_cudagraph_sizes * num_partitions + num_cudagraph_captured=2 * (1 + llama_config.num_layers), ): - outputs.append( - run_model(llama_config, backend=backend, use_compile=True, split_attn=True) - ) - run_model(tractable_config, backend=backend, use_compile=True, split_attn=True) + outputs.append(run_model(llama_config, compile_config_split)) + run_model(tractable_config, 
compile_config_split) for i in range(1, len(outputs)): assert torch.allclose(outputs[0], outputs[i]) @@ -427,14 +448,14 @@ def benchmark(): for piecewise in [False, True]: if piecewise: compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, - splitting_ops=["silly.attention"], + splitting_ops=["silly::attention"], cudagraph_capture_sizes=cudagraph_sizes, ) else: compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, cudagraph_capture_sizes=cudagraph_sizes, ) diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index c0d3f908149f..29c02f6e6a1d 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -8,7 +8,7 @@ import torch from torch.library import Library -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op # Shared library for all compilation test operations # Using "silly" namespace to match existing test expectations @@ -62,5 +62,4 @@ def silly_attention_fake( mutates_args=["out"], fake_impl=silly_attention_fake, target_lib=silly_lib, - tags=(torch._C.Tag.cudagraph_unsafe,), ) diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py new file mode 100644 index 000000000000..c65e5a25934d --- /dev/null +++ b/tests/compile/test_aot_compile.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import tempfile +from contextlib import contextmanager + +import pytest +import torch + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CompilationConfig, + CompilationMode, + VllmConfig, + set_current_vllm_config, +) +from vllm.forward_context import set_forward_context +from vllm.utils.torch_utils import is_torch_equal_or_newer + + +def reference_fn(x: torch.Tensor): + assert x.shape[0] <= 42 + assert x.shape[0] % 2 == 0 + for _ in range(3000): + x = x + x.shape[0] + return x + + +@support_torch_compile +class CompiledMod(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + return reference_fn(x) + + +def make_vllm_config() -> VllmConfig: + return VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + ) + ) + + +@contextmanager +def use_vllm_config(vllm_config: VllmConfig): + with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config): + yield + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + vllm_config = make_vllm_config() + args = (torch.randn(10, 10),) + expected = reference_fn(*args) + with use_vllm_config(vllm_config): + m.setenv("VLLM_USE_AOT_COMPILE", "0") + with ( + pytest.raises(RuntimeError, match="Detected recompile"), + torch.compiler.set_stance("fail_on_recompile"), + ): + CompiledMod(vllm_config=vllm_config)(*args) + + m.setenv("VLLM_USE_AOT_COMPILE", "1") + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + actual = CompiledMod(vllm_config=vllm_config)(*args) + assert torch.allclose(actual, expected) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_force_aot_load(monkeypatch: pytest.MonkeyPatch): + with 
tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m: + args = (torch.randn(10, 10),) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError): + CompiledMod(vllm_config=vllm_config)(*args) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_save_and_load(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + args = (torch.randn(10, 10),) + + with tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + expected = CompiledMod(vllm_config=vllm_config)(*args) + + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + ret = CompiledMod(vllm_config=vllm_config)(*args) + assert torch.allclose(ret, expected) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_shape_env(monkeypatch: pytest.MonkeyPatch): + """ + Test that the shape environment is correctly serialized and preserved + when loading from cache. + """ + with monkeypatch.context() as m: + args = (torch.randn(10, 10),) + + with tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + compiled_mod = CompiledMod(vllm_config=vllm_config) + compiled_mod(*args) + artifacts = compiled_mod.aot_compiled_fn._artifacts + guards_string = artifacts.compiled_fn.shape_env.format_guards() + assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" + + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + compiled_mod = CompiledMod(vllm_config=vllm_config) + compiled_mod(*args) + artifacts = compiled_mod.aot_compiled_fn._artifacts + guards_string = artifacts.compiled_fn.shape_env.format_guards() + assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 88ad4f81df50..71ee22878143 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -10,6 +10,7 @@ from vllm.compilation.collective_fusion import AsyncTPPass from vllm.config import ( CompilationConfig, + CompilationMode, DeviceConfig, ModelConfig, PassConfig, @@ -24,7 +25,7 @@ initialize_model_parallel, ) from vllm.platforms import current_platform -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables from ..models.registry import HF_EXAMPLE_MODELS from ..utils import ( @@ -142,7 +143,7 @@ def ops_in_model_before(self): return [torch.ops.vllm.reduce_scatter.default] def ops_in_model_after(self): - return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default] + return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default] class TestAGScaledMMModel(_BaseScaledMMModel): @@ -195,7 +196,7 @@ def ops_in_model_before(self): return [torch.ops.vllm.reduce_scatter.default] def ops_in_model_after(self): - return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default] + return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default] class 
TestAGCutlassScaledMMModel(_BaseScaledMMModel): @@ -243,9 +244,15 @@ def ops_in_model_after(self): @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dynamic", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_async_tp_pass_replace( - test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype + test_model: str, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + dynamic: bool, ): if ( test_model @@ -269,7 +276,15 @@ def run_torch_spawn(fn, nprocs): # torch.distributed and cuda torch.multiprocessing.spawn( fn, - args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + args=( + num_processes, + test_model, + batch_size, + seq_len, + hidden_size, + dtype, + dynamic, + ), nprocs=nprocs, ) @@ -284,6 +299,7 @@ def async_tp_pass_on_test_model( seq_len: int, hidden_size: int, dtype: torch.dtype, + dynamic: bool, ): current_platform.seed_everything(0) @@ -317,7 +333,7 @@ def async_tp_pass_on_test_model( # this is a fake model name to construct the model config # in the vllm_config, it's not really used. - model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8" vllm_config.model_config = ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) @@ -325,12 +341,24 @@ def async_tp_pass_on_test_model( async_tp_pass = AsyncTPPass(vllm_config) backend = TestBackend(async_tp_pass) + assert ( + async_tp_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + assert ( + async_tp_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) + model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor hidden_states = torch.randn( (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False ) + if dynamic: + torch._dynamo.mark_dynamic(hidden_states, 0) + compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) @@ -382,7 +410,7 @@ def test_async_tp_pass_correctness( common_args.append("--enforce-eager") compilation_config = { - "level": 3, + "mode": CompilationMode.VLLM_COMPILE, "compile_sizes": [2, 4, 8], "splitting_ops": [], "pass_config": {"enable_async_tp": async_tp_enabled}, diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 4bcefb30b2e6..132a838b8d44 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import dataclasses import pytest -from vllm.config import CompilationLevel -from vllm.utils import cuda_device_count_stateless +from vllm.config import CompilationMode +from vllm.utils.torch_utils import cuda_device_count_stateless from ..utils import compare_all_settings @@ -23,7 +21,7 @@ class TestSetting: # we cannot afford testing the full Cartesian product -# of all models and all levels +# of all models and all modes @pytest.mark.parametrize( "test_setting", [ @@ -79,14 +77,15 @@ class TestSetting: method="encode", ), # vision language model - TestSetting( - model="microsoft/Phi-3.5-vision-instruct", - model_args=["--trust-remote-code", "--max-model-len", 
"2048"], - pp_size=2, - tp_size=1, - attn_backend="FLASH_ATTN", - method="generate_with_image", - ), + # See https://github.com/vllm-project/vllm/issues/26716. + # TestSetting( + # model="microsoft/Phi-3.5-vision-instruct", + # model_args=["--trust-remote-code", "--max-model-len", "2048"], + # pp_size=2, + # tp_size=1, + # attn_backend="FLASH_ATTN", + # method="generate_with_image", + # ), ], ) def test_compile_correctness( @@ -111,41 +110,44 @@ def test_compile_correctness( with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) final_args = [ - "--enforce-eager", *model_args, "-pp", str(pp_size), "-tp", str(tp_size), + "-O.cudagraph_mode=none", ] all_args: list[list[str]] = [] all_envs: list[dict[str, str] | None] = [] - for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.PIECEWISE, + for comp_mode in [ + CompilationMode.STOCK_TORCH_COMPILE, + CompilationMode.DYNAMO_TRACE_ONCE, + CompilationMode.VLLM_COMPILE, ]: - all_args.append(final_args + [f"-O{level}"]) - all_envs.append({}) + for mode in [CompilationMode.NONE, comp_mode]: + all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"]) - # inductor will change the output, so we only compare if the output - # is close, not exactly the same. - compare_all_settings( - model, - all_args, - all_envs, - method=method if method != "generate" else "generate_close", - ) - all_envs.clear() - all_args.clear() + # inductor will change the output, so we only compare if the output + # is close, not exactly the same. + compare_all_settings( + model, + all_args, + all_envs, + method=method if method != "generate" else "generate_close", + ) + all_envs.clear() + all_args.clear() - for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.DYNAMO_AS_IS, - CompilationLevel.DYNAMO_ONCE, + for mode in [ + CompilationMode.NONE, + CompilationMode.STOCK_TORCH_COMPILE, + CompilationMode.DYNAMO_TRACE_ONCE, + CompilationMode.VLLM_COMPILE, ]: - all_args.append(final_args + [f"-O{level}"]) + all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"]) + all_envs.append({}) all_envs.append({}) compare_all_settings(model, all_args * 3, all_envs, method=method) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 0da7f58a2f5f..4145e84c2ee0 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -1,13 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +from contextlib import nullcontext + import pytest from vllm.compilation.counter import compilation_counter +from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig -from vllm.utils import _is_torch_equal_or_newer +from vllm.config.compilation import CompilationMode +from vllm.engine.arg_utils import EngineArgs +from vllm.platforms import current_platform +from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer def test_version(): + # Test the version comparison logic using the private function assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev") assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev") assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev") @@ -17,9 +25,26 @@ def test_version(): def test_use_cudagraphs_dynamic(): vllm_config = VllmConfig() + # Default V1 configuration now starts without cudagraphs enabled; the + # engine decides when to capture based on runtime 
settings instead of a + # blanket default. assert vllm_config.compilation_config.use_cudagraph +def test_copy_pass(): + vllm_config = VllmConfig() + inductor_pass = FixFunctionalizationPass(vllm_config) + copied_inductor_pass = copy.deepcopy(inductor_pass) + assert ( + copied_inductor_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) + assert ( + copied_inductor_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + + def test_custom_op(): # proper syntax _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"]) @@ -85,16 +110,16 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073 @pytest.mark.forked -def test_dynamo_as_is(vllm_runner, monkeypatch): +def test_stock_torch_compile(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(dynamo_as_is_count=1), + compilation_counter.expect(stock_torch_compile_count=1), # loading the model causes compilation (if enabled) to happen vllm_runner( "facebook/opt-125m", - compilation_config={"level": 1}, + compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE}, gpu_memory_utilization=0.4, ) as _, ): @@ -107,11 +132,11 @@ def test_no_compilation(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), + compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0), # loading the model causes compilation (if enabled) to happen vllm_runner( "facebook/opt-125m", - compilation_config={"level": 0}, + compilation_config={"mode": CompilationMode.NONE}, gpu_memory_utilization=0.4, ) as _, ): @@ -125,7 +150,7 @@ def test_enforce_eager(vllm_runner, monkeypatch): monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), + compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0), # loading the model causes compilation (if enabled) to happen vllm_runner( "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4 @@ -137,58 +162,147 @@ def test_enforce_eager(vllm_runner, monkeypatch): def test_splitting_ops_dynamic(): # Default config config = VllmConfig() - assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE - assert config.compilation_config.splitting_ops_contain_attention() + # Default V1 config leaves cudagraph mode unset; splitting ops are only + # populated when the engine decides to use piecewise compilation. + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + assert not config.compilation_config.splitting_ops_contain_attention() # When use_inductor_graph_partition=True - if _is_torch_equal_or_newer("2.9.0.dev"): - # inductor graph partition is only available in PyTorch 2.9+. - # this is a fast config check so we are not using pytest.skip. 
+ if is_torch_equal_or_newer("2.9.0.dev"): config = VllmConfig( compilation_config=CompilationConfig( - use_inductor_graph_partition=True, splitting_ops=["silly_attention"] + mode=CompilationMode.VLLM_COMPILE, + use_inductor_graph_partition=True, + splitting_ops=["vllm::unified_attention"], ) ) - # should ignore splitting_ops - assert config.compilation_config.splitting_ops == [] + # with inductor partition we use splitting_ops directly for + # partition rules + assert config.compilation_config.splitting_ops == ["vllm::unified_attention"] - # When attn_fusion pass enabled. + # When attn_fusion pass enabled, splitting_ops now default to attention ops. config = VllmConfig( compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, pass_config={"enable_attn_fusion": True, "enable_noop": True}, custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, ) ) - assert config.compilation_config.splitting_ops == [] - # cudagraph mode also fall back to FULL - assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL - - # splitting_ops can not contain attention ops when attn_fusion - # pass enabled. - with pytest.raises(AssertionError): - config = VllmConfig( - compilation_config=CompilationConfig( - pass_config={"enable_attn_fusion": True, "enable_noop": True}, - custom_ops=["+quant_fp8"], - cudagraph_mode=CUDAGraphMode.PIECEWISE, - # work around for accessing all attntion ops - splitting_ops=CompilationConfig()._attention_ops, - ) - ) + # With the new simplified logic, attention fusion works with splitting_ops + assert config.compilation_config.splitting_ops_contain_attention() + # cudagraph mode remains PIECEWISE + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE # When both use_inductor_graph_partition and attn_fusion pass enabled. - if _is_torch_equal_or_newer("2.9.0.dev"): + if is_torch_equal_or_newer("2.9.0.dev"): config = VllmConfig( compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, pass_config={"enable_attn_fusion": True, "enable_noop": True}, custom_ops=["+quant_fp8"], cudagraph_mode=CUDAGraphMode.PIECEWISE, ) ) - assert config.compilation_config.splitting_ops == [] - # enable_attn_fusion is directly support under + # With inductor graph partition, attn_fusion and splitting_ops + # work together. Default splitting_ops include attention ops. + assert config.compilation_config.splitting_ops_contain_attention() + # enable_attn_fusion is directly supported under # use_inductor_graph_partition=True, and cudagraph_mode # is unchanged. 
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE + + +def test_resolve_operator_overload(): + import torch + + from vllm.compilation.partition_rules import resolve_defined_ops + + # Test valid operator names + resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"]) + assert len(resolved) == 2 + assert resolved[0] is torch.ops.aten.mm.default + assert resolved[1] is torch.ops.aten.addmm.default + + # Test that invalid operators are skipped (not raising exceptions) + resolved = resolve_defined_ops( + [ + "aten::mm.default", + "aten::nonexistent_op.default", # This should be skipped + "aten::addmm.default", + ] + ) + assert len(resolved) == 2 # Only 2 valid ops + assert resolved[0] is torch.ops.aten.mm.default + assert resolved[1] is torch.ops.aten.addmm.default + + +@pytest.mark.skipif( + not current_platform.support_static_graph_mode(), + reason="Skip if not cudagraph mode supported", +) +@pytest.mark.parametrize( + ( + "cudagraph_capture_sizes", + "max_cudagraph_capture_size", + "tp_size", + "enable_sequence_parallelism", + "max_num_batched_tokens", + "use_cudagraph", + "expected_max_size", + ), + [ + (None, None, 1, False, 2048, True, 512), + ([1, 2, 4], 4, 1, False, 2048, True, 4), + ([1, 2, 4], 8, 1, False, 2048, True, RuntimeError), + ([1, 256], None, 1, False, 2048, 256), + ([], None, 1, False, 2048, False, 0), + (None, 0, 1, False, 2048, False, 0), + # truncated to nearest multiple of 8 or 16 + (None, 257, 1, False, 2048, True, 256), + ([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list + ([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP + ([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens + # the list should contain at least 1 element when use cudagraph + ([], None, 1, False, 2048, True, RuntimeError), + # the max capturing size should be >= 1 when use cudagraph + (None, 0, 1, False, 2048, True, RuntimeError), + ], +) +def test_cudagraph_sizes_post_init( + cudagraph_capture_sizes, + max_cudagraph_capture_size, + tp_size, + enable_sequence_parallelism, + max_num_batched_tokens, + use_cudagraph, + expected_max_size, +): + ctx = nullcontext() + if isinstance(expected_max_size, Exception): + ctx = pytest.raises(expected_max_size) + + cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE + with ctx: + compilation_config = CompilationConfig( + cudagraph_capture_sizes=cudagraph_capture_sizes, + max_cudagraph_capture_size=max_cudagraph_capture_size, + pass_config={ + "enable_sequence_parallelism": enable_sequence_parallelism, + "enable_fusion": True, + "enable_noop": True, + }, + cudagraph_mode=cudagraph_mode, + ) + engine_args = EngineArgs( + model="facebook/opt-125m", + tensor_parallel_size=tp_size, + max_num_batched_tokens=max_num_batched_tokens, + compilation_config=compilation_config, + ) + vllm_config = engine_args.create_engine_config() + + assert ( + vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size + ) diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index d7048821bb60..c9d01f2317d2 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest import torch from torch import nn @@ -8,12 +9,13 @@ from vllm.config import ( CacheConfig, CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, 
set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 @@ -65,18 +67,40 @@ def run_model( return output.cpu() -def test_ignore_torch_compile_decorator(): +@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) +def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch): + # disable compile cache so that we can count the number of compilations + # appropriately + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + # piecewise vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, - splitting_ops=["silly.attention"], + splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], + use_inductor_graph_partition=use_inductor_graph_partition, ) ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE + expected_num_graphs_seen = 1 + expected_num_cudagraph_captured = ( + 4 # num_cudagraph_sizes * num cudagraphs to capture + ) + if use_inductor_graph_partition: + expected_num_piecewise_graphs_seen = 1 + expected_num_piecewise_capturable_graphs_seen = 1 + expected_num_backend_compilations = 1 + else: + expected_num_piecewise_graphs_seen = 3 + expected_num_piecewise_capturable_graphs_seen = 2 + expected_num_backend_compilations = 2 + @support_torch_compile class A(nn.Module): def __init__( @@ -103,12 +127,11 @@ class C(B): ... # A has support_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=expected_num_graphs_seen, + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, + num_cudagraph_captured=expected_num_cudagraph_captured, ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) @@ -130,12 +153,11 @@ class C(B): ... 
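For the counter expectations that follow, the relevant rule is that the decorator nearest to a class wins down the hierarchy (A: support, B: ignore, C: support again). A toy analogue of that precedence, deliberately not using the vLLM decorators themselves:

    # Toy stand-ins for @support_torch_compile / @ignore_torch_compile, only to
    # illustrate the precedence the test asserts: the closest decorator wins.
    def mark(compiled: bool):
        def deco(cls):
            cls._compiled = compiled
            return cls
        return deco

    @mark(True)   # analogue of @support_torch_compile on A
    class A: ...

    @mark(False)  # analogue of @ignore_torch_compile on B
    class B(A): ...

    @mark(True)   # C re-enables compilation, overriding B
    class C(B): ...

    assert (A._compiled, B._compiled, C._compiled) == (True, False, True)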
# C's support_torch_compile should override B's ignore_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=expected_num_graphs_seen, + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, + num_cudagraph_captured=expected_num_cudagraph_captured, ): run_model(vllm_config, mod_C, cudagraph_runtime_mode) @@ -178,16 +200,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def test_conditional_compile_enable_if(): +@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) +def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch): + # disable compile cache so that we can count the number of compilations + # appropriately + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + vllm_config = VllmConfig( cache_config=CacheConfig( kv_sharing_fast_prefill=True, ), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, - splitting_ops=["silly.attention"], + splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], + use_inductor_graph_partition=use_inductor_graph_partition, ), ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE @@ -195,17 +226,26 @@ def test_conditional_compile_enable_if(): with set_current_vllm_config(vllm_config): mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() + if use_inductor_graph_partition: + expected_num_piecewise_graphs_seen = 2 + expected_num_piecewise_capturable_graphs_seen = 2 + expected_num_backend_compilations = 2 + else: + expected_num_piecewise_graphs_seen = 6 + expected_num_piecewise_capturable_graphs_seen = 4 + expected_num_backend_compilations = 4 + # A has support_torch_compile but enable_if fn returns False # enalbe_if will be True for B, so we expect mod1 and mod2 # to be compiled with compilation_counter.expect( num_graphs_seen=2, - num_piecewise_graphs_seen=6, + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, # 3 piecewise graphs per instance of B() - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + # num_cudagraph_sizes * num cudagraphable graphs to capture ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) @@ -216,23 +256,34 @@ def test_conditional_compile_enable_if(): kv_sharing_fast_prefill=False, ), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, use_cudagraph=True, - splitting_ops=["silly.attention"], + splitting_ops=["silly::attention"], cudagraph_capture_sizes=[1, 2], + use_inductor_graph_partition=use_inductor_graph_partition, ), ) with set_current_vllm_config(vllm_config): mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() + if use_inductor_graph_partition: + expected_num_piecewise_graphs_seen = 1 + 
expected_num_piecewise_capturable_graphs_seen = 1 + expected_num_backend_compilations = 1 + else: + # 3 attn ops and 4 non-attn ops + expected_num_piecewise_graphs_seen = 7 + expected_num_piecewise_capturable_graphs_seen = 4 + expected_num_backend_compilations = 4 + with compilation_counter.expect( num_graphs_seen=1, - num_piecewise_graphs_seen=7, + num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, # 3 attn ops and 4 non-attn ops - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, + num_backend_compilations=expected_num_backend_compilations, num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + # num_cudagraph_sizes * num cudagraphable graphs to capture ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 8ccae4cfb9df..0ad8c17d8668 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,22 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - -import logging import tempfile -from typing import Any, Union +from pathlib import Path +from typing import Any import pytest import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.attention.backends.registry import _Backend -from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig +from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test @@ -24,23 +20,24 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ ("facebook/opt-125m", {}), - ( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - { - "dtype": torch.float16, - }, - ), ( "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", - { - "dtype": torch.float16, - }, + {"dtype": torch.float16}, ), - ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: + TEST_MODELS.extend( + [ + ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + {"dtype": torch.float16}, + ), + ] + ) + # TODO: figure out why this fails. 
if False and is_quant_method_supported("gguf"): # noqa: SIM223 TEST_MODELS.append( @@ -82,70 +79,79 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None): @pytest.mark.parametrize( - "optimization_level", - [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE], + "compilation_mode", + [CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE], ) -@pytest.mark.parametrize("model_info", models_list(all=True)) +@pytest.mark.parametrize("model, model_kwargs", models_list(all=True)) @create_new_process_for_each_test() def test_full_graph( monkeypatch: pytest.MonkeyPatch, - model_info: tuple[str, dict[str, Any]], - optimization_level: int, + model: str, + model_kwargs: dict[str, Any], + compilation_mode: int, ): - model, model_kwargs = model_info + if ( + "w8a8" in model + or "w8w8" in model + and current_platform.has_device_capability((10, 0)) + ): + # int8 removed on Blackwell: + pytest.skip("int8 support removed on Blackwell") with monkeypatch.context(): print(f"MODEL={model}") - run_model(optimization_level, model, model_kwargs) + run_model(compilation_mode, model, **model_kwargs) # TODO(luka) add other supported compilation config scenarios here @pytest.mark.parametrize( - "compilation_config, model_info", + "compilation_config, model, model_kwargs", [ # additional compile sizes, only some of the models ( - CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]), - model, + CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]), + *model_info, ) - for model in models_list(all=False) + for model_info in models_list(all=False) ] + [ # RMSNorm + quant fusion, only 8-bit quant models ( CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm"], pass_config=PassConfig(enable_fusion=True, enable_noop=True), ), - model, + *model_info, ) - for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) + for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) ] + [ # Test depyf integration works ( CompilationConfig( - level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir() + mode=CompilationMode.VLLM_COMPILE, + debug_dump_path=Path(tempfile.gettempdir()), ), - ("facebook/opt-125m", {}), + "facebook/opt-125m", + {}, ), ] + [ # graph inductor partition ( CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, # inductor graph partition uses # torch._C.Tag.cudagraph_unsafe to specify splitting ops use_inductor_graph_partition=True, cudagraph_mode=CUDAGraphMode.PIECEWISE, compile_sizes=[1, 2], ), - model, + *model_info, ) - for model in models_list(all=False) + for model_info in models_list(all=False) if is_torch_equal_or_newer("2.9.0.dev") ], ) @@ -153,23 +159,31 @@ def test_full_graph( @create_new_process_for_each_test() def test_custom_compile_config( compilation_config: CompilationConfig, - model_info: tuple[str, dict[str, Any]], + model: str, + model_kwargs: dict[str, Any], ): + if ( + "w8a8" in model + or "w8w8" in model + and current_platform.has_device_capability((10, 0)) + ): + # int8 removed on Blackwell: + pytest.skip("int8 support removed on Blackwell") + if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer( "2.9.0.dev" ): pytest.skip("inductor graph partition is only available in PyTorch 2.9+") - model, model_kwargs = model_info print(f"MODEL={model}") - run_model(compilation_config, model, model_kwargs) + run_model(compilation_config, model, **model_kwargs) 
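The reworked run_model further down accepts either a bare CompilationMode value or a full CompilationConfig and normalizes the former into the latter; the dispatch shape, reduced to a self-contained sketch with a placeholder config class:

    from dataclasses import dataclass

    # Placeholder standing in for vllm.config.CompilationConfig in this sketch.
    @dataclass
    class StubCompilationConfig:
        mode: int

    def normalize(compile_config: "int | StubCompilationConfig") -> StubCompilationConfig:
        # Same shape as the new run_model(): a config object passes through
        # unchanged, while a bare mode integer is wrapped into a config.
        if isinstance(compile_config, StubCompilationConfig):
            return compile_config
        return StubCompilationConfig(mode=compile_config)

    assert normalize(3).mode == 3
    assert normalize(StubCompilationConfig(mode=0)).mode == 0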
@pytest.mark.parametrize( - "optimization_level", - [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE], + "compilation_mode", + [CompilationMode.NONE, CompilationMode.VLLM_COMPILE], ) -def test_fp8_kv_scale_compile(optimization_level: int): +def test_fp8_kv_scale_compile(compilation_mode: int): model = "Qwen/Qwen2-0.5B" model_kwargs = { "quantization": "fp8", @@ -177,50 +191,16 @@ def test_fp8_kv_scale_compile(optimization_level: int): "calculate_kv_scales": True, "max_model_len": 512, } - run_model(optimization_level, model, model_kwargs) - + run_model(compilation_mode, model, **model_kwargs) -def test_inductor_graph_partition_attn_fusion(caplog_vllm): - if not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available in PyTorch 2.9+") - model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" - compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_inductor_graph_partition=True, - cudagraph_mode=CUDAGraphMode.PIECEWISE, - custom_ops=["+quant_fp8"], - pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), +def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): + compilation_config = ( + compile_config + if isinstance(compile_config, CompilationConfig) + else CompilationConfig(mode=compile_config) ) - model_kwargs = { - "kv_cache_dtype": "fp8", - "max_model_len": 1024, - } - with ( - caplog_vllm.at_level(logging.DEBUG), - global_force_attn_backend_context_manager(_Backend.FLASHINFER), - ): - run_model(compilation_config, model, model_kwargs) - try: - assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, ( - caplog_vllm.text - ) - except AssertionError: - # Note: this message is only triggered when the compilation goes - # through the custom pass. Due to multiple layers of cache on - # PyTorch side, the compilation of a graph may be cached such - # that custom pass directly goes through cache. In this case, - # we go through this branch and assert that the pass is not - # triggered. 
- assert "Fused quantization" not in caplog_vllm.text - - -def run_model( - compile_config: Union[int, CompilationConfig], - model: str, - model_kwargs: dict[str, Any], -): prompts = [ "Hello, my name is", "The president of the United States is", @@ -228,12 +208,17 @@ def run_model( "The future of AI is", ] sampling_params = SamplingParams(temperature=0) + # Allow override from model_kwargs + model_kwargs = {"tensor_parallel_size": 1, **model_kwargs} + model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs} + + # No cudagraphs by default + if compilation_config.cudagraph_mode is None: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + llm = LLM( model=model, - enforce_eager=True, - tensor_parallel_size=1, - disable_custom_all_reduce=True, - compilation_config=compile_config, + compilation_config=compilation_config, **model_kwargs, ) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index ae17bc67b1fb..11ae96e930da 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -11,7 +11,13 @@ from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + ModelConfig, + PassConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape @@ -48,8 +54,7 @@ def forward(self, x): return y def example_inputs(self, num_tokens=32, hidden_size=128): - dtype = torch.float16 if TEST_FP8 else torch.float32 - return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),) + return (torch.rand(num_tokens, hidden_size * 2),) def ops_in_model(self, do_fusion): if TEST_FP8 and do_fusion: @@ -67,15 +72,11 @@ def __init__(self, hidden_size=16, intermediate_size=32): self.hidden_size = hidden_size self.intermediate_size = intermediate_size - dtype = torch.float16 if TEST_FP8 else torch.float32 - self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size), dtype=dtype) + torch.empty((intermediate_size, hidden_size)) ) self.norm = RMSNorm(intermediate_size, 1e-05) - self.norm.weight = torch.nn.Parameter( - torch.ones(intermediate_size, dtype=dtype) - ) + self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size)) torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -112,9 +113,8 @@ def forward(self, hidden_states, residual): return norm_output, residual_output def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): - dtype = torch.float16 if TEST_FP8 else torch.float32 - hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + hidden_states = torch.randn((batch_size * seq_len, hidden_size)) + residual = torch.randn((batch_size * seq_len, hidden_size)) return (hidden_states, residual) def ops_in_model(self, do_fusion): @@ -145,10 +145,9 @@ def forward(self, positions, q, k): return q_rotated, k_rotated def example_inputs(self, num_tokens=32, head_dim=64): - dtype = torch.float16 positions = torch.arange(num_tokens, dtype=torch.long) - q = torch.randn(num_tokens, head_dim, dtype=dtype) - k = 
torch.randn(num_tokens, head_dim, dtype=dtype) + q = torch.randn(num_tokens, head_dim) + k = torch.randn(num_tokens, head_dim) return (positions, q, k) def ops_in_model(self, do_fusion): @@ -166,7 +165,7 @@ def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000): self.hidden_size = head_dim * num_heads self.qkv_proj = torch.nn.Linear( - self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16 + self.hidden_size, self.hidden_size * 3, bias=False ) self.rotary_emb = get_rope( @@ -190,10 +189,9 @@ def forward(self, positions, hidden_states): return qkv_updated def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4): - dtype = torch.float16 hidden_size = head_dim * num_heads positions = torch.arange(num_tokens, dtype=torch.long) - hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + hidden_states = torch.randn(num_tokens, hidden_size) return (positions, hidden_states) def ops_in_model(self, do_fusion): @@ -211,48 +209,58 @@ def ops_not_in_model(self): ] +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): +def test_fix_functionalization( + model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype +): torch.set_default_device("cuda") - - vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True) + torch.set_default_dtype(dtype) + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + custom_ops=["all"], + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True), + ), ) - noop_pass = NoOpEliminationPass(vllm_config) - fusion_pass = RMSNormQuantFusionPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - - passes = ( - [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] - if do_fusion - else [noop_pass, cleanup_pass] - ) - func_pass = FixFunctionalizationPass(vllm_config) - backend_func = TestBackend(*passes, func_pass) - backend_no_func = TestBackend(*passes) + with set_current_vllm_config(vllm_config): + assert RMSNorm.enabled() + noop_pass = NoOpEliminationPass(vllm_config) + fusion_pass = RMSNormQuantFusionPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) + + passes = ( + [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] + if do_fusion + else [noop_pass, cleanup_pass] + ) + func_pass = FixFunctionalizationPass(vllm_config) - model = model_class() - torch.compile(model, backend=backend_func)(*model.example_inputs()) - torch.compile(model, backend=backend_no_func)(*model.example_inputs()) + backend_func = TestBackend(*passes, func_pass) + backend_no_func = TestBackend(*passes) - # check if the functionalization pass is applied - for op in model.ops_in_model(do_fusion): - find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + model = model_class() + torch.compile(model, backend=backend_func)(*model.example_inputs()) + torch.compile(model, backend=backend_no_func)(*model.example_inputs()) - # make sure the ops were all de-functionalized - found = dict() 
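Why the per-tensor dtype= arguments could be dropped from the test models in test_functionalization.py (sketch): torch factory functions pick up the process-wide default dtype, which the test now sets once per parametrization via torch.set_default_dtype(dtype).

import torch

torch.set_default_dtype(torch.bfloat16)
assert torch.randn(4, 4).dtype == torch.bfloat16
torch.set_default_dtype(torch.float32)  # restore the usual default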
- for node in backend_func.graph_post_pass.nodes: + # check if the functionalization pass is applied for op in model.ops_in_model(do_fusion): - if is_func(node, op): - found[op] = True - for op in model.ops_not_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model(do_fusion)) - assert all(not found.get(op) for op in model.ops_not_in_model()) + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(do_fusion): + if is_func(node, op): + found[op] = True + for op in model.ops_not_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model(do_fusion)) + assert all(not found.get(op) for op in model.ops_not_in_model()) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 7c2233643229..286f2276367a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,15 +5,18 @@ import torch import vllm.plugins -from vllm.compilation.fusion import ( - FUSED_OPS, - QUANT_OPS, - FusedRMSQuantKey, - RMSNormQuantFusionPass, -) +from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass +from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -32,6 +35,9 @@ FP8_DTYPE = current_platform.fp8_dtype() +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + class TestModel(torch.nn.Module): def __init__( @@ -45,18 +51,18 @@ def __init__( ): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch - self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] - self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN quant_scale = ScaleDesc(torch.float32, static, group_shape) - self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) + self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] else: - self.scale = [None for _ in range(2)] + self.scale = [None for _ in range(3)] self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() - for _ in range(2) + for _ in range(3) ] with override_cutlass_fp8_supported(not cuda_force_torch): @@ -65,8 +71,12 @@ def __init__( act_quant_group_shape=group_shape, ) + self.enable_rms_norm_custom_op = self.norm[0].enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() + def forward(self, x): - resid = torch.sqrt(x) + # avoid having graph input be an arg to a pattern directly + x = resid = 
torch.relu(x) y = self.norm[0](x) x2 = self.fp8_linear.apply( @@ -78,24 +88,44 @@ def forward(self, x): x3 = self.fp8_linear.apply( y2, self.w[1], self.wscale[1], input_scale=self.scale[1] ) + y3, resid = self.norm[2](x3, resid) # use resid here - return y3 - def ops_in_model_before(self): - return [QUANT_OPS[self.key]] + x4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) + + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [ - FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)], + FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], + FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], ] + def ops_in_model_before(self): + return ( + [QUANT_OPS[self.quant_key]] + if self.enable_quant_fp8_custom_op + else [torch.ops.aten.reciprocal] + ) + + def ops_in_model_before_partial(self): + return ( + [RMS_OP, RMS_ADD_OP] + if self.enable_rms_norm_custom_op + else [torch.ops.aten.rsqrt] + ) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("static", [True, False]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. @pytest.mark.parametrize( @@ -105,19 +135,32 @@ def ops_in_model_after(self): not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" ) def test_fusion_rmsnorm_quant( - dtype, hidden_size, num_tokens, eps, static, cuda_force_torch + dtype, + hidden_size, + num_tokens, + eps, + static, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, + cuda_force_torch, ): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths + custom_ops = [] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + if enable_quant_fp8_custom_op: + custom_ops.append("+quant_fp8") vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"], + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, pass_config=PassConfig(enable_fusion=True, enable_noop=True), - ) + ), ) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work @@ -126,31 +169,39 @@ def test_fusion_rmsnorm_quant( cleanup_pass = PostCleanupPass(vllm_config) backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend2 = TestBackend(noop_pass, cleanup_pass) model = TestModel(hidden_size, eps, static, cuda_force_torch) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) torch._dynamo.mark_dynamic(x, 0) - result = model(x) + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) - model2 = torch.compile(model, backend=backend) - result2 = model2(x) + model_unfused = torch.compile(model, backend=backend2) + result_unfused = model_unfused(x) - # Higher tol for dynamic, even higher for bfloat16 - if static: - ATOL, RTOL = (1e-3, 1e-3) - elif dtype == torch.float16: + if dtype == torch.float16: ATOL, RTOL = (2e-3, 2e-3) else: ATOL, RTOL = (1e-2, 1e-2) - torch.testing.assert_close(result, result2, atol=ATOL, 
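A rough sketch of why torch.ops.aten.reciprocal stands in for the quant op in ops_in_model_before when the quant_fp8 custom op is disabled: the torch-native static fp8 quantization path is essentially a reciprocal-multiply-clamp-cast, so the reciprocal node is the stable marker left in the FX graph. This simplification ignores dynamic per-token scales and padding details.

import torch

def static_fp8_quant_native(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    fp8 = torch.float8_e4m3fn
    info = torch.finfo(fp8)
    # x / scale expressed as a reciprocal-multiply, clamped to the fp8 range
    return (x * scale.reciprocal()).clamp(min=info.min, max=info.max).to(fp8)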
rtol=RTOL) + torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL) - assert fusion_pass.matched_count == 2 - - # In pre-nodes, fp8 quant should be there and fused kernels should not + assert fusion_pass.matched_count == 3 backend.check_before_ops(model.ops_in_model_before()) - - # In post-nodes, fused kernels should be there and fp8 quant should not + backend.check_before_ops( + model.ops_in_model_before_partial(), fully_replaced=False + ) backend.check_after_ops(model.ops_in_model_after()) + + # If RMSNorm custom op is disabled (native/torch impl used), + # there's a risk that the fused add doesn't get included in the + # replacement and only the rms part gets fused with quant. + # Hence, we check only 2 add nodes are left (final fused rmsnorm add). + if not enable_rms_norm_custom_op: + n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g)) + # 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each) + assert n_add_nodes(backend.graph_pre_pass) == 7 + assert n_add_nodes(backend.graph_post_pass) == 2 diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 7e5c460db174..6d0a0ed7d89d 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -6,17 +6,19 @@ import torch import vllm.envs as envs +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.compilation.collective_fusion import AllReduceFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, DeviceConfig, ModelConfig, PassConfig, VllmConfig, + set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -25,11 +27,11 @@ ) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, GroupShape, - QuantFP8, ) from vllm.platforms import current_platform -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables from ..utils import has_module_attribute, multi_gpu_test from .backend import TestBackend @@ -40,33 +42,30 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm = self.norm(all_reduce) - return norm + def forward(self, x): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(x) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) - def ops_in_model_before(self): - return [torch.ops.vllm.all_reduce.default] + z2 = torch.mm(y, self.w[0]) + x2 = tensor_model_parallel_all_reduce(z2) - def ops_in_model_after(self): - return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + y2, resid = self.norm[1](x2, resid) + z3 = torch.mm(y2, self.w[1]) + x3 = tensor_model_parallel_all_reduce(z3) -class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, 
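Where the "7 = 1 (RMS) + 3x2" add count above comes from (sketch of a generic native RMSNorm, ignoring the float32 upcast vLLM performs): the plain norm contributes one aten.add for the "+ eps", and each fused-add variant contributes a second add for the residual sum. With one plain norm and three fused-add norms in TestModel, that gives 1 + 3 * 2 = 7 adds before fusion.

import torch

def rms_norm_native(x, weight, eps=1e-6, residual=None):
    if residual is not None:
        x = x + residual  # residual add: only in the fused-add variant
        residual = x
    var = x.pow(2).mean(dim=-1, keepdim=True)
    y = x * torch.rsqrt(var + eps) * weight  # the "+ eps" is the other aten.add
    return y if residual is None else (y, residual)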
eps=1e-6): - super().__init__() - self.hidden_size = hidden_size - self.eps = eps - self.norm = RMSNorm(hidden_size, eps) + y3, resid = self.norm[2](x3, resid) + + z4 = torch.mm(y3, self.w[2]) + x4 = tensor_model_parallel_all_reduce(z4) - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm, _ = self.norm(all_reduce, residual) - return norm + y4, resid = self.norm[3](x4, resid) + return y4 def ops_in_model_before(self): return [torch.ops.vllm.all_reduce.default] @@ -75,24 +74,53 @@ def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] -class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): +class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) - self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) - self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) - - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.static_scaled_fp8_quant( - self.output, norm_output.contiguous(), self.scale + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.w = [ + torch.rand(hidden_size, hidden_size) + .to(dtype=current_platform.fp8_dtype()) + .t() + for _ in range(3) + ] + + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) + + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + + def forward(self, hidden_states): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + z2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) + + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + z3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] ) - return self.output, residual_output + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + z4 = self.fp8_linear.apply( + y3, self.w[2], self.wscale[2], input_scale=self.scale[2] + ) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -100,7 +128,9 @@ def ops_in_model_after(self): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.static_scaled_fp8_quant.default, + torch.ops._C.static_scaled_fp8_quant.default + if self.fp8_linear.quant_fp8.enabled() + else torch.ops.aten.reciprocal.default, ] @@ -109,25 +139,48 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps - self.norm = RMSNorm(hidden_size, eps) - self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) - - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = 
round_up(token_num, 128) - scale_n = hidden_size // 16 - rounded_n = round_up(scale_n, 4) - self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) - - def forward(self, hidden_states, residual): - view = hidden_states.reshape(-1, self.hidden_size) - all_reduce = tensor_model_parallel_all_reduce(view) - norm_output, residual_output = self.norm(all_reduce, residual) - norm_output = norm_output.reshape(-1, norm_output.shape[-1]) - torch.ops._C.scaled_fp4_quant( - self.output, norm_output, self.output_scale, self.scale + self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] + + self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)] + self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] + self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)] + + wq_gen, wscale_gen = zip( + *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale)) + ) + self.wq, self.wscale = list(wq_gen), list(wscale_gen) + print(f"{self.wq=}, {self.wscale=}") + + def forward(self, hidden_states): + # avoid having graph input be an arg to a pattern directly + z = torch.relu(hidden_states) + x = resid = tensor_model_parallel_all_reduce(z) + y = self.norm[0](x) + + yq, y_scale = scaled_fp4_quant(y, self.agscale[0]) + z2 = cutlass_scaled_fp4_mm( + yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype + ) + + x2 = tensor_model_parallel_all_reduce(z2) + y2, resid = self.norm[1](x2, resid) + + yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1]) + z3 = cutlass_scaled_fp4_mm( + yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype ) - return self.output, residual_output, self.output_scale + + x3 = tensor_model_parallel_all_reduce(z3) + y3, resid = self.norm[2](x3, resid) # use resid here + + yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2]) + z4 = cutlass_scaled_fp4_mm( + yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype + ) + x4 = tensor_model_parallel_all_reduce(z4) + y4, resid = self.norm[3](x4, resid) # use resid here + return y4 def ops_in_model_after(self): return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] @@ -141,19 +194,19 @@ def ops_in_model_before(self): @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "test_model", + "test_model, enable_quant_fp8_custom_op", [ - TestAllReduceRMSNormModel, - TestAllReduceFusedAddRMSNormModel, - TestAllReduceFusedAddRMSNormStaticQuantFP8Model, - # TODO: Enable with torch==2.8.0 - # TestAllReduceFusedAddRMSNormStaticQuantFP4Model, + (TestAllReduceRMSNormModel, False), + (TestAllReduceRMSNormStaticQuantFP8Model, True), + (TestAllReduceRMSNormStaticQuantFP8Model, False), + (TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False), ], ) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) -@pytest.mark.parametrize("hidden_size", [16]) +@pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") @@ -167,6 +220,8 @@ def test_all_reduce_fusion_pass_replace( seq_len: int, hidden_size: int, dtype: torch.dtype, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ): num_processes = 2 if ( @@ -181,7 +236,16 @@ def test_all_reduce_fusion_pass_replace( def run_torch_spawn(fn, nprocs): 
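How the newly parametrized values reach the spawned worker (sketch): torch.multiprocessing.spawn calls fn(local_rank, *args) in every subprocess, so the args tuple built in run_torch_spawn has to line up positionally with the worker's signature after its leading rank argument. The names below are illustrative.

import torch.multiprocessing as mp

def _worker(local_rank: int, world_size: int, payload: str) -> None:
    print(local_rank, world_size, payload)

if __name__ == "__main__":
    mp.spawn(_worker, args=(2, "hello"), nprocs=2)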
torch.multiprocessing.spawn( fn, - args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + args=( + num_processes, + test_model, + batch_size, + seq_len, + hidden_size, + dtype, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, + ), nprocs=nprocs, ) @@ -196,6 +260,8 @@ def all_reduce_fusion_pass_on_test_model( seq_len: int, hidden_size: int, dtype: torch.dtype, + enable_rms_norm_custom_op, + enable_quant_fp8_custom_op, ): current_platform.seed_everything(0) @@ -217,40 +283,50 @@ def all_reduce_fusion_pass_on_test_model( init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) + custom_ops = [] + if enable_rms_norm_custom_op: + custom_ops.append("+rms_norm") + if enable_quant_fp8_custom_op: + custom_ops.append("+quant_fp8") + vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"] + mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops ) ) vllm_config.compilation_config.pass_config = PassConfig( enable_fi_allreduce_fusion=True, enable_noop=True ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + vllm_config.parallel_config.rank = local_rank # Setup rank for debug path # this is a fake model name to construct the model config # in the vllm_config, it's not really used. - model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8" vllm_config.model_config = ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) + with set_current_vllm_config(vllm_config): + all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) + noop_pass = NoOpEliminationPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend( + noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass + ) - all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) - noop_pass = NoOpEliminationPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass) - - token_num = batch_size * seq_len - model = test_model_cls(hidden_size, token_num) + token_num = batch_size * seq_len + model = test_model_cls(hidden_size, token_num) - hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - residual = torch.randn((token_num, hidden_size), requires_grad=False) + hidden_states = torch.randn((token_num, hidden_size), requires_grad=False) - compiled_model = torch.compile(model, backend=backend) - compiled_model(hidden_states, residual) + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states) - assert all_reduce_fusion_pass.matched_count == 1 - backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) - backend.check_after_ops(model.ops_in_model_after()) - del all_reduce_fusion_pass + assert all_reduce_fusion_pass.matched_count == 4, ( + f"{all_reduce_fusion_pass.matched_count=}" + ) + backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) + backend.check_after_ops(model.ops_in_model_after()) + del all_reduce_fusion_pass diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 0f2e3bffbd31..fecb1e2e918f 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -1,26 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -from typing import Optional import pytest import torch._dynamo from tests.compile.backend import LazyInitPass, TestBackend +from tests.utils import flat_product from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes +from vllm.compilation.matcher_utils import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import ( CacheConfig, CompilationConfig, - CompilationLevel, + CompilationMode, ModelConfig, PassConfig, SchedulerConfig, @@ -29,21 +29,18 @@ ) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8StaticTensorSym, kNvfp4Quant, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.flashinfer import has_flashinfer from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 -# globals needed for string-import custom Dynamo backend field -backend: Optional[TestBackend] = None -backend_unfused: Optional[TestBackend] = None - class AttentionQuantPatternModel(torch.nn.Module): """Base model for AttentionQuantPattern fusion.""" @@ -105,6 +102,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: num_blocks = batch_size * max_blocks backend = self.attn.backend + # TODO(luka) use get_kv_cache_stride_order # Create dummy KV cache for the selected backend if backend == _Backend.ROCM_ATTN: # k/v as 1st dimention @@ -242,26 +240,40 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): ) +MODELS_FP8: list[tuple[str, type]] = [] +MODELS_FP4: list[tuple[str, type]] = [] +HEADS: list[tuple[int, int]] = [] +SPLIT_ATTENTION: list[bool] = [] +BACKENDS_FP8: list[_Backend] = [] +BACKENDS_FP4: list[_Backend] = [] + if current_platform.is_cuda(): - MODELS = [ + HEADS = [(64, 8), (40, 8)] + MODELS_FP8 = [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", TestAttentionFp8StaticQuantPatternModel, - ), + ) + ] + MODELS_FP4 = [ ( "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", TestAttentionNvfp4QuantPatternModel, - ), + ) ] - HEADS = [(64, 8), (40, 8)] + BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER] + BACKENDS_FP4 = [_Backend.FLASHINFER] + elif current_platform.is_rocm(): - MODELS = [ + HEADS = [(32, 8), (40, 8)] + MODELS_FP8 = [ ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] - HEADS = [(32, 8), (40, 8)] -else: - MODELS = [] - HEADS = [] + BACKENDS = [ + _Backend.ROCM_AITER_UNIFIED_ATTN, + _Backend.ROCM_ATTN, + _Backend.TRITON_ATTN, + ] @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @@ -270,46 +282,36 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8] ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) 
-@pytest.mark.parametrize("model_name, model_class", MODELS) @pytest.mark.parametrize( - "backend", - [_Backend.FLASHINFER] - if current_platform.is_cuda() - else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN], -) -# TODO(boyuan): test inductor graph partition on rocm -@pytest.mark.parametrize( - "use_inductor_graph_partition", - [False] if current_platform.is_rocm() else [False, True], + "backend, model_name, model_class, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8 + list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])) + # quant_fp4 only has the custom impl + + list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])), ) @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" ) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif( - current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)), - reason="On CUDA only test on SM100(Blackwell)", -) -@pytest.mark.skipif( - not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" -) def test_attention_quant_pattern( num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, + custom_ops: str, model_name: str, model_class: type[AttentionQuantPatternModel], backend: _Backend, - use_inductor_graph_partition: bool, dist_init, - caplog_vllm, ): """Test AttentionStaticQuantPattern fusion pass""" + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") - if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + custom_ops_list = custom_ops.split(",") if custom_ops else [] device = torch.device("cuda:0") torch.manual_seed(42) @@ -322,9 +324,8 @@ def test_attention_quant_pattern( ), scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=["+quant_fp8"], - use_inductor_graph_partition=use_inductor_graph_partition, + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops_list, ), cache_config=CacheConfig(cache_dtype="fp8"), ) @@ -359,8 +360,9 @@ def test_attention_quant_pattern( forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) - # Run model directly without compilation and fusion - result_unfused = model_unfused(q, k, v) + # Run model directly without fusion + # Still compile so query QuantFP8 has closer numerics + result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v) # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( @@ -415,14 +417,25 @@ def test_attention_quant_pattern( ) # Check attn fusion support - quant_key = model_class.quant_key + quant_key: QuantKey = model_class.quant_key attn_fusion_supported = [ layer.impl.fused_output_quant_supported(quant_key) for key, layer in vllm_config.compilation_config.static_forward_context.items() ] - if any(attn_fusion_supported): - # Check quantization ops in the graph before and after fusion - test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True) + assert sum(attn_fusion_supported) == len(attn_fusion_supported), ( + "All layers should support attention fusion" + ) + + # Check quantization ops in the graph before and after 
fusion + quant_op = ( + torch.ops.aten.reciprocal + if "-quant_fp8" in custom_ops_list + else QUANT_OPS[quant_key] + ) + + # Note: for fp8, fully_replaced=False because query quant ops remain in graph. + # Only output quant ops are fused into attention. + test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant) # access the underlying `AttnFusionPass` on the `LazyInitPass` assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py new file mode 100644 index 000000000000..d66c60ccb5b2 --- /dev/null +++ b/tests/compile/test_fusions_e2e.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import itertools +import logging +from collections.abc import Iterable +from typing import Any, NamedTuple + +import pytest +import regex as re + +from tests.v1.attention.utils import _Backend +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.torch_utils import is_torch_equal_or_newer + +from ..utils import flat_product, multi_gpu_test + + +class ModelBackendTestCase(NamedTuple): + model_name: str + model_kwargs: dict[str, Any] + backend: _Backend + attention_fusions: int + allreduce_fusions: int | None = None + + +MODELS_FP8: list[ModelBackendTestCase] = [] +MODELS_FP4: list[ModelBackendTestCase] = [] +MODELS: list[ModelBackendTestCase] = [] # tp-only + +if current_platform.is_cuda(): + MODELS_FP8 = [ + ModelBackendTestCase( + # Use smaller model for L40s in CI + model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=32, + allreduce_fusions=65, + ), + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=_Backend.FLASHINFER, + attention_fusions=48, + allreduce_fusions=96, + ), + ] + + MODELS_FP4 = [ + ModelBackendTestCase( + model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"), + backend=_Backend.FLASHINFER, + attention_fusions=48, + allreduce_fusions=96, + ), + ] + + # TP only + MODELS = [ + ModelBackendTestCase( + model_name="meta-llama/Llama-3.1-8B-Instruct", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=0, + allreduce_fusions=65, + ), + ] + +elif current_platform.is_rocm(): + MODELS_FP8 = [ + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.TRITON_ATTN, + attention_fusions=32, + ), + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.ROCM_ATTN, + attention_fusions=32, + ), + ModelBackendTestCase( + model_name="amd/Llama-3.1-8B-Instruct-FP8-KV", + model_kwargs=dict(max_model_len=1024), + backend=_Backend.ROCM_AITER_UNIFIED_ATTN, + attention_fusions=32, + ), + ] + +# TODO(luka) test both in nightly +CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"] + + +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, " + "attention_fusions, allreduce_fusions, custom_ops", + # Test attention+quant_fp8 fusion with custom and torch impls 
of QuantFP8 + list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8)) + # quant_fp4 only has the custom impl + + list(flat_product(MODELS_FP4, [""])), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +def test_attn_quant( + model_name: str, + model_kwargs: dict[str, Any], + backend: _Backend, + attention_fusions: int, + allreduce_fusions: int, + custom_ops: str, + inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if backend == _Backend.FLASHINFER and ( + not current_platform.is_device_capability((10, 0)) or not has_flashinfer() + ): + pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer") + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = None + else: + # FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at + # CUDAGraphMode.NONE here because it derives an attention backend that + # does not support full cudagraphs + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + custom_ops=custom_ops_list, + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + splitting_ops=splitting_ops, + # Common + mode=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model(compilation_config, model_name, **model_kwargs) + + matches = re.findall( + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", + log_holder.text, + ) + assert len(matches) == 1, log_holder.text + assert int(matches[0]) == attention_fusions + + +# TODO(luka) test both in nightly +CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"] + + +def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]: + for op_list in itertools.product(*custom_ops_lists): + yield ",".join(op_list) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, model_kwargs, backend, " + "attention_fusions, allreduce_fusions, custom_ops", + # Toggle RMSNorm and QuantFP8 for FP8 models + list( + flat_product( + MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM) + ) + ) + # Toggle RMSNorm for FP4 models and unquant models + + list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)), +) +@pytest.mark.parametrize("inductor_graph_partition", [True, False]) +@pytest.mark.skipif( + not current_platform.is_cuda() + or not has_flashinfer() + or not current_platform.has_device_capability(90), + reason="allreduce+rmsnorm fusion requires flashinfer", +) +def test_tp2_attn_quant_allreduce_rmsnorm( + model_name: str, + model_kwargs: dict, + backend: _Backend, + attention_fusions: int, + allreduce_fusions: int, + custom_ops: str, + 
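flat_product is imported from tests.utils but not defined in this patch; a presumed sketch of its behaviour, consistent with how the parametrizations here use it: a cartesian product whose tuple elements (including NamedTuple cases such as ModelBackendTestCase) are flattened into one flat tuple, so each row lines up with a flat "a, b, c, ..." parametrize string.

import itertools

def flat_product(*iterables):
    for combo in itertools.product(*iterables):
        flat = []
        for item in combo:
            if isinstance(item, tuple):  # NamedTuple cases flatten into their fields
                flat.extend(item)
            else:
                flat.append(item)
        yield tuple(flat)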
inductor_graph_partition: bool, + caplog_mp_spawn, + monkeypatch, +): + if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("Inductor graph partition requires torch>=2.9") + + custom_ops_list = custom_ops.split(",") if custom_ops else [] + + if inductor_graph_partition: + mode = CUDAGraphMode.FULL_AND_PIECEWISE + splitting_ops: list[str] | None = None + else: + mode = CUDAGraphMode.FULL_DECODE_ONLY + splitting_ops = [] + + # Disable, compile cache to make sure custom passes run. + # Otherwise, we can't verify fusion happened through the logs. + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + # To capture subprocess logs, we need to know whether spawn or fork is used. + # Force spawn as it is more general. + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name) + + compilation_config = CompilationConfig( + # Testing properties + use_inductor_graph_partition=inductor_graph_partition, + cudagraph_mode=mode, + custom_ops=custom_ops_list, + splitting_ops=splitting_ops, + # Common + mode=CompilationMode.VLLM_COMPILE, + pass_config=PassConfig( + enable_attn_fusion=True, + enable_noop=True, + enable_fi_allreduce_fusion=True, + ), + # Inductor caches custom passes by default as well via uuid + inductor_compile_config={"force_disable_caches": True}, + ) + + with caplog_mp_spawn(logging.DEBUG) as log_holder: + run_model( + compilation_config, model_name, tensor_parallel_size=2, **model_kwargs + ) + matches = re.findall( + r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", + log_holder.text, + ) + assert len(matches) == 2, log_holder.text + + assert int(matches[0]) == attention_fusions + assert int(matches[1]) == attention_fusions + + matches = re.findall( + r"collective_fusion.py:\d+] Replaced (\d+) patterns", + log_holder.text, + ) + assert len(matches) == 2, log_holder.text + + assert int(matches[0]) == allreduce_fusions + assert int(matches[1]) == allreduce_fusions + + +def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs): + compilation_config = ( + compile_config + if isinstance(compile_config, CompilationConfig) + else CompilationConfig(mode=compile_config) + ) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + # Allow override from model_kwargs + model_kwargs = {"tensor_parallel_size": 1, **model_kwargs} + model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs} + + # No cudagraphs by default + if compilation_config.cudagraph_mode is None: + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + llm = LLM( + model=model, + compilation_config=compilation_config, + **model_kwargs, + ) + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. 
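Shape of the log line that the regex assertions in these tests scrape (the sample text below is made up; only the pattern is taken from the tests):

import regex as re

sample = "DEBUG 10-01 fusion_attn.py:123] Fused quant onto 32 attention nodes"
matches = re.findall(r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes", sample)
assert matches == ["32"]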
+ for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py index fda7f4e3bafa..0ccc1a016162 100644 --- a/tests/compile/test_noop_elimination.py +++ b/tests/compile/test_noop_elimination.py @@ -6,21 +6,29 @@ import vllm from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig +from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig from .backend import TestBackend @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) -@pytest.mark.parametrize("num_tokens", [256, 1024]) +# Important edge case is when `num_tokens == buffer_size` +@pytest.mark.parametrize( + ("num_tokens", "buffer_size"), [(256, 256), (256, 512), (1024, 1024), (1024, 1025)] +) @pytest.mark.parametrize("hidden_size", [64, 4096]) -def test_noop_elimination(dtype, num_tokens, hidden_size): +def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype) + def forward(self, x): + x += self.pos_embed[: x.shape[0]] # Chain of reshapes y = x.reshape(-1, 128, 32) z = y.reshape(-1, 4096) @@ -42,7 +50,7 @@ def forward(self, x): vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, pass_config=PassConfig(enable_noop=True), ) ) @@ -65,9 +73,10 @@ def forward(self, x): torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) # The no-op reshape and slice should be eliminated. + # The initial slice on the positional embedding should remain. # The chain of reshapes should be fused into a single reshape. 
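Sketch of the distinction the no-op elimination assertions below rely on: a slice whose length is provably the full (symbolic) dimension can be removed, while the positional-embedding slice indexes a constant-sized buffer with a dynamic token count and therefore has to stay, even for the num_tokens == buffer_size parametrization.

import torch

def provably_noop(x: torch.Tensor) -> torch.Tensor:
    return x[0 : x.shape[0]]  # same symbolic length on both sides -> removable

def must_stay(x: torch.Tensor, pos_embed: torch.Tensor) -> torch.Tensor:
    return x + pos_embed[: x.shape[0]]  # constant buffer vs. dynamic length -> kept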
assert backend.op_count(torch.ops.aten.reshape.default) == 1 - assert backend.op_count(torch.ops.aten.slice.Tensor) == 0 + assert backend.op_count(torch.ops.aten.slice.Tensor) == 1 assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0 @@ -89,7 +98,7 @@ def forward(self, x): vllm_config = VllmConfig( compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, pass_config=PassConfig(enable_noop=True), ) ) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index ac561d2e8f84..1c40c599f748 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -7,7 +7,7 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.pass_manager import PostGradPassManager -from vllm.config import VllmConfig +from vllm.config import ModelConfig, VllmConfig # dummy custom pass that doesn't inherit @@ -42,7 +42,8 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: ], ) def test_pass_manager_uuid(callable): - config = VllmConfig() + # Some passes need dtype to be set + config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16)) pass_manager = PostGradPassManager() pass_manager.configure(config) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index afb31cb95be0..e909cf7393ad 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -18,6 +18,8 @@ ModelConfig, PassConfig, VllmConfig, + get_current_vllm_config, + set_current_vllm_config, ) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( @@ -27,7 +29,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables from ..utils import multi_gpu_test from .backend import TestBackend @@ -42,9 +44,7 @@ class TestModel(torch.nn.Module): - def __init__( - self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None - ): + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -95,13 +95,11 @@ def ops_in_model(self): class TestQuantModel(torch.nn.Module): - def __init__( - self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None - ): + def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.vllm_config = vllm_config + self.vllm_config = get_current_vllm_config() self.gate_proj = torch.nn.Parameter( torch.empty((intermediate_size, hidden_size)), requires_grad=False ) @@ -266,68 +264,84 @@ def sequence_parallelism_pass_on_test_model( initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass - vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( + compilation_config = CompilationConfig( pass_config=PassConfig( enable_sequence_parallelism=True, enable_fusion=enable_fusion, enable_noop=True, ) ) # NoOp needed for fusion - vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to 
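The config-threading pattern the reworked sequence-parallelism test models rely on (sketch): anything constructed inside a set_current_vllm_config block can recover the active config via get_current_vllm_config, so it no longer needs to be passed through constructor arguments. This assumes the model stores the result, as TestQuantModel does above.

with set_current_vllm_config(vllm_config):
    model = TestQuantModel(hidden_size=16, intermediate_size=32)
    assert model.vllm_config is vllm_config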
construct the model config # in the vllm_config, it's not really used. - model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig( + model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8" + model_config = ModelConfig( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) - noop_pass = NoOpEliminationPass(vllm_config) - sequence_parallelism_pass = SequenceParallelismPass(vllm_config) - func_pass = FixFunctionalizationPass(vllm_config) - cleanup_pass = PostCleanupPass(vllm_config) - - passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass] + vllm_config = VllmConfig( + model_config=model_config, + device_config=device_config, + compilation_config=compilation_config, + ) - if enable_fusion: - fusion_pass = RMSNormQuantFusionPass(vllm_config) - passes_for_backend.append(fusion_pass) + with set_current_vllm_config(vllm_config): + noop_pass = NoOpEliminationPass(vllm_config) + sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + func_pass = FixFunctionalizationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + assert ( + sequence_parallelism_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + assert ( + sequence_parallelism_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) + passes_for_backend: list[VllmInductorPass] = [ + noop_pass, + sequence_parallelism_pass, + ] - passes_for_backend.append(cleanup_pass) + if enable_fusion: + fusion_pass = RMSNormQuantFusionPass(vllm_config) + passes_for_backend.append(fusion_pass) - backend_no_func = TestBackend(*passes_for_backend) - backend_func = TestBackend(*passes_for_backend, func_pass) + passes_for_backend.append(cleanup_pass) - model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config) + backend_no_func = TestBackend(*passes_for_backend) + backend_func = TestBackend(*passes_for_backend, func_pass) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + model = test_model_cls(hidden_size, hidden_size * 2) - compiled_model_no_func = torch.compile(model, backend=backend_no_func) - compiled_model_no_func(hidden_states, residual) - compiled_model_func = torch.compile(model, backend=backend_func) - compiled_model_func(hidden_states, residual) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) - assert sequence_parallelism_pass.matched_count == 1 + compiled_model_no_func = torch.compile(model, backend=backend_no_func) + compiled_model_no_func(hidden_states, residual) + compiled_model_func = torch.compile(model, backend=backend_func) + compiled_model_func(hidden_states, residual) - # In pre-nodes, all reduce should be there, - # reduce scatter and all gather should not - backend_no_func.check_before_ops(model.ops_in_model_before()) + assert sequence_parallelism_pass.matched_count == 1 - # In post-nodes, reduce scatter and all gather should be there, - # all reduce should not - backend_no_func.check_after_ops(model.ops_in_model_after()) + # In pre-nodes, all reduce should be there, + # reduce scatter and all gather should not + backend_no_func.check_before_ops(model.ops_in_model_before()) - # check if the functionalization pass is applied - for op in model.ops_in_model(): - 
find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + # In post-nodes, reduce scatter and all gather should be there, + # all reduce should not + backend_no_func.check_after_ops(model.ops_in_model_after()) - # make sure the ops were all de-functionalized - found = dict() - for node in backend_func.graph_post_pass.nodes: + # check if the functionalization pass is applied for op in model.ops_in_model(): - if is_func(node, op): - found[op] = True - assert all(found[op] for op in model.ops_in_model()) + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in model.ops_in_model(): + if is_func(node, op): + found[op] = True + assert all(found[op] for op in model.ops_in_model()) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 16a4271655ef..0ddb82b7c3fc 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import cast +import itertools import pytest import torch @@ -16,7 +16,13 @@ from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + CompilationMode, + PassConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -25,7 +31,7 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, - cutlass_fp8_supported, + maybe_create_device_identity, ) from vllm.platforms import current_platform @@ -54,6 +60,8 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR, ) + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() + self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() def forward(self, x): y = self.silu_and_mul(x) @@ -61,7 +69,14 @@ def forward(self, x): return x2 def ops_in_model_before(self): - return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]] + return [ + SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul, + ( + QUANT_OPS[kFp8StaticTensorSym] + if self.enable_quant_fp8_custom_op + else torch.ops.aten.reciprocal + ), + ] def ops_in_model_after(self): return [FUSED_OPS[kFp8StaticTensorSym]] @@ -77,6 +92,7 @@ def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): assert silu_and_mul_nvfp4_quant_supported self.silu_and_mul = SiluAndMul() + self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() # create nvfp4 weight w = torch.rand((hidden_size, hidden_size)) @@ -101,7 +117,10 @@ def forward(self, x): return out def ops_in_model_before(self): - return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]] + return [ + SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul, + QUANT_OPS[kNvfp4Quant], + ] def ops_in_model_after(self): return [FUSED_OPS[kNvfp4Quant]] @@ -110,67 +129,80 @@ def ops_in_model_after(self): 
@pytest.mark.parametrize("num_tokens", [32, 64]) @pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False]) @pytest.mark.parametrize( - "model_class", - cast( - list[type], - [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] - if is_nvfp4_supported() - else [TestSiluMulFp8QuantModel], - ), + "model_class, enable_quant_fp8_custom_op, cuda_force_torch", + list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])) + + [(TestSiluMulNvfp4QuantModel, False, False)], ) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. -@pytest.mark.parametrize( - "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] -) @pytest.mark.skipif( envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm" ) def test_fusion_silu_and_mul_quant( - num_tokens, hidden_size, dtype, model_class, cuda_force_torch + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel], + enable_silu_mul_custom_op: bool, + enable_quant_fp8_custom_op: bool, + cuda_force_torch: bool, ): - if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch: - pytest.skip("Duplicate tests for NVFP4") + if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): + pytest.skip("NVFP4 is not supported on this GPU.") torch.set_default_device("cuda") torch.set_default_dtype(dtype) + maybe_create_device_identity() x = torch.rand(num_tokens, hidden_size * 2) # Reshape pass is needed for the fusion pass to work - config = VllmConfig() - config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=True, enable_noop=True) + custom_ops = [] + if enable_silu_mul_custom_op: + custom_ops.append("+silu_and_mul") + if enable_quant_fp8_custom_op: + custom_ops.append("+quant_fp8") + config = VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ), ) - fusion_pass = ActivationQuantFusionPass(config) - passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)] - backend = TestBackend(*passes) - model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x) + with set_current_vllm_config(config): + fusion_pass = ActivationQuantFusionPass(config) - # First dimension dynamic - torch._dynamo.mark_dynamic(x, 0) + passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)] + backend = TestBackend(*passes) + model = model_class( + hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x + ) - result = model(x) + # First dimension dynamic + torch._dynamo.mark_dynamic(x, 0) - model2 = torch.compile(model, backend=backend) - result2 = model2(x) + result = model(x) - # Check that it gives the same answer - if model_class == TestSiluMulFp8QuantModel: - atol, rtol = 1e-3, 1e-3 - elif model_class == TestSiluMulNvfp4QuantModel: - atol, rtol = 1e-1, 1e-1 + model2 = torch.compile(model, backend=backend) + result2 = model2(x) - torch.testing.assert_close( - result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol - ) + # Check that it gives the same answer + if model_class == TestSiluMulFp8QuantModel: + atol, rtol = 1e-3, 1e-3 + elif model_class == TestSiluMulNvfp4QuantModel: + atol, rtol = 1e-1, 1e-1 + + 
torch.testing.assert_close( + result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol + ) - assert fusion_pass.matched_count == 1 + assert fusion_pass.matched_count == 1 - # In pre-nodes, quant op should be present and fused kernels should not - backend.check_before_ops(model.ops_in_model_before()) + # In pre-nodes, quant op should be present and fused kernels should not + backend.check_before_ops(model.ops_in_model_before()) - # In post-nodes, fused kernels should be present and quant op should not - backend.check_after_ops(model.ops_in_model_after()) + # In post-nodes, fused kernels should be present and quant op should not + backend.check_after_ops(model.ops_in_model_after()) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 34db5a999cbd..da0afd9eaa49 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -1,16 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationLevel +from vllm.config import CompilationMode class MyMod(torch.nn.Module): - def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): if cache is not None: return x + cache return x * 2 @@ -21,14 +20,14 @@ def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") super().__init__( - compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE + compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE ) - def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): # this is the function to be compiled return self.model(x, cache) - def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None): # let torch.compile compile twice if len(self.compiled_codes) == 2: dispatch_id = 0 if cache is None else 1 diff --git a/tests/config/test_multimodal_config.py b/tests/config/test_multimodal_config.py new file mode 100644 index 000000000000..b1a09d88ed9d --- /dev/null +++ b/tests/config/test_multimodal_config.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.attention.backends.registry import _Backend +from vllm.config.multimodal import MultiModalConfig + + +def test_mm_encoder_attn_backend_str_conversion(): + config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN") + assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN + + +def test_mm_encoder_attn_backend_invalid(): + with pytest.raises(ValueError): + MultiModalConfig(mm_encoder_attn_backend="not_a_backend") + + +def test_mm_encoder_attn_backend_hash_updates(): + base_hash = MultiModalConfig().compute_hash() + overridden_hash = MultiModalConfig( + mm_encoder_attn_backend=_Backend.FLASH_ATTN + ).compute_hash() + assert base_hash != overridden_hash diff --git a/tests/conftest.py b/tests/conftest.py index 4713e1238596..ec0179b9cd5a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# ruff: noqa +import contextlib +import 
pathlib +from copy import deepcopy from tblib import pickling_support +# ruff: noqa + # Install support for pickling exceptions so that we can nicely propagate # failures from tests running in a subprocess. # This should be run before any custom exception subclasses are defined. @@ -21,7 +24,7 @@ from collections.abc import Generator from contextlib import nullcontext from enum import Enum -from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast +from typing import Any, Callable, TypedDict, TypeVar, cast import numpy as np import pytest @@ -40,7 +43,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs -from vllm import LLM, SamplingParams +from vllm import LLM, SamplingParams, envs from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset @@ -57,7 +60,8 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import is_list_of, set_default_torch_num_threads +from vllm.utils.collection_utils import is_list_of +from vllm.utils.torch_utils import set_default_torch_num_threads logger = init_logger(__name__) @@ -68,7 +72,7 @@ _M = TypeVar("_M") -_PromptMultiModalInput = Union[list[_M], list[list[_M]]] +_PromptMultiModalInput = list[_M] | list[list[_M]] PromptImageInput = _PromptMultiModalInput[Image.Image] PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]] @@ -267,7 +271,7 @@ def get_default_device(self): return "cpu" if current_platform.is_cpu() else current_platform.device_type - def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: + def wrap_device(self, x: _T, device: str | None = None) -> _T: if x is None or isinstance(x, (bool,)): return x @@ -287,14 +291,14 @@ def __init__( model_name: str, dtype: str = "auto", *, - model_kwargs: Optional[dict[str, Any]] = None, + model_kwargs: dict[str, Any] | None = None, trust_remote_code: bool = True, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, # Set this to avoid hanging issue - default_torch_num_threads: Optional[int] = None, + default_torch_num_threads: int | None = None, ) -> None: init_ctx = ( nullcontext() @@ -319,7 +323,7 @@ def _init( model_name: str, dtype: str = "auto", *, - model_kwargs: Optional[dict[str, Any]] = None, + model_kwargs: dict[str, Any] | None = None, trust_remote_code: bool = True, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, @@ -334,7 +338,7 @@ def _init( trust_remote_code=trust_remote_code, ) self.device = self.get_default_device() - self.dtype = torch_dtype = _get_and_verify_dtype( + self.dtype = dtype = _get_and_verify_dtype( self.model_name, self.config, dtype=dtype, @@ -342,7 +346,7 @@ def _init( ) model_kwargs = model_kwargs if model_kwargs is not None else {} - model_kwargs.setdefault("torch_dtype", torch_dtype) + model_kwargs.setdefault("dtype", dtype) if is_sentence_transformer: # Lazy init required for AMD CI @@ -388,7 +392,7 @@ def _init( if not skip_tokenizer_init: self.tokenizer = AutoTokenizer.from_pretrained( model_name, - torch_dtype=torch_dtype, + dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -398,7 +402,7 @@ def _init( self.processor = AutoProcessor.from_pretrained( model_name, - torch_dtype=torch_dtype, + dtype=dtype, 
trust_remote_code=trust_remote_code, ) if skip_tokenizer_init: @@ -406,11 +410,11 @@ def _init( def get_inputs( self, - prompts: Union[list[str], list[list[int]]], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - ) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]: + prompts: list[str] | list[list[int]], + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, + ) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]: if images is not None: assert len(prompts) == len(images) @@ -420,9 +424,7 @@ def get_inputs( if audios is not None: assert len(prompts) == len(audios) - all_inputs: list[ - Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]] - ] = [] + all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = [] for i, prompt in enumerate(prompts): if isinstance(prompt, str): processor_kwargs: dict[str, Any] = { @@ -494,10 +496,10 @@ def classify(self, prompts: list[str]) -> list[str]: def generate( self, - prompts: Union[list[str], list[list[int]]], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + prompts: list[str] | list[list[int]], + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: all_inputs = self.get_inputs( @@ -522,11 +524,11 @@ def generate( def generate_greedy( self, - prompts: Union[list[str], list[list[int]]], + prompts: list[str] | list[list[int]], max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[tuple[list[int], str]]: outputs = self.generate( @@ -546,9 +548,9 @@ def generate_beam_search( prompts: list[str], beam_width: int, max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, ) -> list[tuple[list[list[int]], list[str]]]: outputs = self.generate( prompts, @@ -574,9 +576,9 @@ def generate_greedy_logprobs( self, prompts: list[str], max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[list[torch.Tensor]]: all_inputs = self.get_inputs( @@ -624,7 +626,7 @@ def _hidden_states_to_seq_logprobs( def _hidden_states_to_logprobs( self, hidden_states: tuple[tuple[torch.Tensor, ...], ...], - num_logprobs: Optional[int], + num_logprobs: int | None, ) -> tuple[list[dict[int, float]], int]: seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) output_len = len(hidden_states) @@ -652,10 +654,10 @@ def generate_greedy_logprobs_limit( self, prompts: list[str], max_tokens: int, - num_logprobs: Optional[int], - images: Optional[PromptImageInput] = None, - audios: Optional[PromptAudioInput] = None, - videos: Optional[PromptVideoInput] = None, + 
num_logprobs: int | None, + images: PromptImageInput | None = None, + audios: PromptAudioInput | None = None, + videos: PromptVideoInput | None = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: all_inputs = self.get_inputs( @@ -734,20 +736,20 @@ def __init__( model_name: str, runner: RunnerOption = "auto", convert: ConvertOption = "auto", - tokenizer_name: Optional[str] = None, + tokenizer_name: str | None = None, tokenizer_mode: str = "auto", trust_remote_code: bool = True, - seed: Optional[int] = 0, - max_model_len: Optional[int] = 1024, + seed: int | None = 0, + max_model_len: int | None = 1024, dtype: str = "auto", disable_log_stats: bool = True, tensor_parallel_size: int = 1, block_size: int = 16 if not torch.xpu.is_available() else 64, - enable_chunked_prefill: Optional[bool] = False, + enable_chunked_prefill: bool | None = False, swap_space: int = 4, - enforce_eager: Optional[bool] = False, + enforce_eager: bool | None = False, # Set this to avoid hanging issue - default_torch_num_threads: Optional[int] = None, + default_torch_num_threads: int | None = None, **kwargs, ) -> None: init_ctx = ( @@ -785,10 +787,10 @@ def __init__( def get_inputs( self, - prompts: Union[list[str], list[torch.Tensor], list[list[int]]], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + prompts: list[str] | list[torch.Tensor] | list[list[int]], + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, ) -> list[dict[str, Any]]: if any( x is not None and len(x) != len(prompts) for x in [images, videos, audios] @@ -824,11 +826,11 @@ def get_inputs( def generate( self, - prompts: Union[list[str], list[torch.Tensor], list[list[int]]], + prompts: list[str] | list[torch.Tensor] | list[list[int]], sampling_params: SamplingParams, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) @@ -871,11 +873,11 @@ def generate_w_logprobs( self, prompts: list[str], sampling_params: SamplingParams, - images: Optional[PromptImageInput] = None, - audios: Optional[PromptAudioInput] = None, - videos: Optional[PromptVideoInput] = None, + images: PromptImageInput | None = None, + audios: PromptAudioInput | None = None, + videos: PromptVideoInput | None = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) req_outputs = self.llm.generate( @@ -894,11 +896,11 @@ def generate_w_logprobs( def generate_greedy( self, - prompts: Union[list[str], list[torch.Tensor], list[list[int]]], + prompts: list[str] | list[torch.Tensor] | list[list[int]], max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[tuple[list[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) @@ -916,15 
+918,15 @@ def generate_greedy_logprobs( self, prompts: list[str], max_tokens: int, - num_logprobs: Optional[int], - num_prompt_logprobs: Optional[int] = None, - images: Optional[PromptImageInput] = None, - audios: Optional[PromptAudioInput] = None, - videos: Optional[PromptVideoInput] = None, - stop_token_ids: Optional[list[int]] = None, - stop: Optional[list[str]] = None, + num_logprobs: int | None, + num_prompt_logprobs: int | None = None, + images: PromptImageInput | None = None, + audios: PromptAudioInput | None = None, + videos: PromptVideoInput | None = None, + stop_token_ids: list[int] | None = None, + stop: list[str] | None = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]: greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, @@ -957,7 +959,7 @@ def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: perplexities = [] for output in outputs: output = cast(TokensTextLogprobsPromptLogprobs, output) - token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + token_datas = cast(list[dict[int, Logprob] | None], output[3]) assert token_datas[0] is None token_log_probs = [] for token_data in token_datas[1:]: @@ -976,10 +978,10 @@ def generate_beam_search( prompts: list[str], beam_width: int, max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - concurrency_limit: Optional[int] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, + concurrency_limit: int | None = None, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) @@ -1002,9 +1004,9 @@ def classify(self, prompts: list[str]) -> list[list[float]]: def embed( self, prompts: list[str], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, *args, **kwargs, ) -> list[list[float]]: @@ -1013,8 +1015,12 @@ def embed( req_outputs = self.llm.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] - def encode(self, prompts: list[str]) -> list[list[float]]: - req_outputs = self.llm.encode(prompts) + def token_embed(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.encode(prompts, pooling_task="token_embed") + return [req_output.outputs.data for req_output in req_outputs] + + def token_classify(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.encode(prompts, pooling_task="token_classify") return [req_output.outputs.data for req_output in req_outputs] def reward(self, prompts: list[str]) -> list[list[float]]: @@ -1023,8 +1029,8 @@ def reward(self, prompts: list[str]) -> list[list[float]]: def score( self, - text_1: Union[str, list[str]], - text_2: Union[str, list[str]], + text_1: list[str] | str, + text_2: list[str] | str, *args, **kwargs, ) -> list[float]: @@ -1067,6 +1073,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog +@pytest.fixture() +def caplog_mp_fork(): + """ + This fixture enables capturing logs from a forked MP subprocess. + It should be used in conjunction with caplog_vllm. 
+ + By default, subprocess logs do not go through the parent process. + We instead create a queue listener in the parent process which + forwards logs to the logger's other handlers, and add a QueueHandler + to the root logger. Forked subprocesses will inherit the root logger + and pass their messages to the queue, which the listener will forward + to the root logger, which can be captured by caplog. + + Note that this workaround only works for fork; with spawn, the subprocess + reinitializes logging and does not automatically inherit the queue. + We'd have to manually pass the queue to the subprocess at the spawn point. + See caplog_mp_spawn below. + """ + + @contextlib.contextmanager + def ctx(): + import logging.handlers + import multiprocessing as mp + + logger_queue: mp.Queue[logging.LogRecord] = mp.Queue() + logger = logging.getLogger() + handlers = logger.handlers + + # The listener works on a background thread, not inherited by the child. + queue_listener = logging.handlers.QueueListener(logger_queue, *handlers) + queue_listener.start() + + # Add queue handler after creating the listener to avoid cycle + logger.addHandler(logging.handlers.QueueHandler(logger_queue)) + yield + queue_listener.stop() + + return ctx + + +class LogHolder: + def __init__(self): + self.text = None + + +@pytest.fixture() +def caplog_mp_spawn(tmp_path, monkeypatch): + """ + This fixture enables capturing logs from a spawned MP subprocess. + It does not require caplog_vllm (but it only contains logs from the child). + + By default, subprocess logs do not go through the parent process. + We instead add a FileHandler to the config so the spawned child process + writes its logs to a temp file. + In the parent, we read the file and return the contents. + + Note: this method could be extended to fork by either reconfiguring logging + in the parent or using a SocketHandler: + https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501 + """ + + @contextlib.contextmanager + def ctx(level: int | str): + from vllm.logger import DEFAULT_LOGGING_CONFIG + + config_path = tmp_path / "vllm_logging_config.json" + log_path = tmp_path / "vllm.log" + log_holder = LogHolder() + + config = deepcopy(DEFAULT_LOGGING_CONFIG) + if envs.VLLM_LOGGING_CONFIG_PATH: + path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH) + assert path.exists() + config = json.loads(path.read_text()) + + config["loggers"]["vllm"]["handlers"] += ["vllm_file"] + config["handlers"]["vllm_file"] = { + "class": "logging.FileHandler", + "formatter": "vllm", + "level": level, + "filename": log_path.as_posix(), + } + + config_path.write_text(json.dumps(config)) + + with monkeypatch.context() as monkeypatch_ctx: + monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix()) + monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1") + yield log_holder + + log_holder.text = log_path.read_text() + + return ctx + + +@pytest.fixture(scope="session") def num_gpus_available(): """Get number of GPUs without initializing the CUDA context @@ -1226,8 +1327,8 @@ def _find_free_port() -> int: class LocalAssetServer: address: str port: int - server: Optional[http.server.ThreadingHTTPServer] - thread: Optional[threading.Thread] + server: http.server.ThreadingHTTPServer | None + thread: threading.Thread | None def __init__(self, address: str = "127.0.0.1") -> None: self.address = address diff --git a/tests/detokenizer/test_stop_strings.py b/tests/detokenizer/test_stop_strings.py index d59b394393e3..6b829c261035 
100644 --- a/tests/detokenizer/test_stop_strings.py +++ b/tests/detokenizer/test_stop_strings.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any import pytest @@ -15,8 +15,8 @@ def _test_stopping( llm: LLM, expected_output: str, expected_reason: Any, - stop: Optional[list[str]] = None, - stop_token_ids: Optional[list[int]] = None, + stop: list[str] | None = None, + stop_token_ids: list[int] | None = None, include_in_output: bool = False, ) -> None: output = llm.generate( diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 47ceb45057c9..9c146a3323d9 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -from typing import Optional, Union import msgspec import msgspec.msgpack @@ -78,8 +77,8 @@ class MockSubscriber: def __init__( self, - pub_endpoints: Union[str, list[str]], - replay_endpoints: Optional[Union[str, list[str]]] = None, + pub_endpoints: str | list[str], + replay_endpoints: str | list[str] | None = None, topic: str = "", decode_type=SampleBatch, ): @@ -111,7 +110,7 @@ def __init__( self.last_seq = -1 self.decoder = msgspec.msgpack.Decoder(type=decode_type) - def receive_one(self, timeout=1000) -> Union[tuple[int, SampleBatch], None]: + def receive_one(self, timeout=1000) -> tuple[int, SampleBatch] | None: """Receive a single message with timeout""" if not self.sub.poll(timeout): return None diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index c61c4584d837..ba80ee6fb83b 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -5,9 +5,8 @@ Run `pytest tests/distributed/test_comm_ops.py`. 
""" -from __future__ import annotations - -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import pytest import ray diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 53fc9957b910..5495640af07e 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -11,11 +11,11 @@ import json import os from dataclasses import dataclass -from typing import Literal, NamedTuple, Optional +from typing import Literal, NamedTuple import pytest -from vllm.config import RunnerOption +from vllm.config.model import RunnerOption from vllm.logger import init_logger from ..models.registry import HF_EXAMPLE_MODELS @@ -36,7 +36,7 @@ class ParallelSetup(NamedTuple): class CPTestOptions(NamedTuple): multi_node_only: bool - load_format: Optional[str] = None + load_format: str | None = None @dataclass @@ -54,7 +54,7 @@ def detailed( dcp_base: int = 1, multi_node_only: bool = False, runner: RunnerOption = "auto", - load_format: Optional[str] = None, + load_format: str | None = None, ): parallel_setups = [] for eager_mode_val in [False]: @@ -204,17 +204,21 @@ def _compare_cp_with_tp( CP_TEXT_GENERATION_MODELS = { - # [MLA attention only] "deepseek-ai/DeepSeek-V2-Lite-Chat": [ CPTestSettings.detailed(), CPTestSettings.detailed(tp_base=2), ], + "bigcode/gpt_bigcode-santacoder": [ + CPTestSettings.detailed(), + CPTestSettings.detailed(tp_base=2), + ], } CP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] "deepseek-ai/DeepSeek-V2-Lite-Chat", + "bigcode/gpt_bigcode-santacoder", ] diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index 7ca3d3d27b56..7b45ae82c72d 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -15,7 +15,7 @@ get_tp_group, init_distributed_environment, ) -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables def distributed_run(fn, world_size): diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py index f06f6771a4a0..f17b7997c588 100644 --- a/tests/distributed/test_events.py +++ b/tests/distributed/test_events.py @@ -263,3 +263,52 @@ def test_data_parallel_rank_tagging(publisher_config): pub_1.shutdown() sub_0.close() sub_1.close() + + +def test_event_publisher_factory(): + """Test event publisher factory creation behavior under different configurations""" + from vllm.config.kv_events import KVEventsConfig + from vllm.distributed.kv_events import ZmqEventPublisher + + # test config is None + publisher = EventPublisherFactory.create(None, DP_RANK) + assert isinstance(publisher, NullEventPublisher) + publisher.shutdown() + + # test disable kv cache events + config = KVEventsConfig( + enable_kv_cache_events=False, + publisher="zmq", # Even if zmq is specified, should return NullEventPublisher + endpoint="tcp://localhost:5557", + ) + publisher = EventPublisherFactory.create(config, DP_RANK) + assert isinstance(publisher, NullEventPublisher) + publisher.shutdown() + + # test zmq publisher + config = KVEventsConfig( + enable_kv_cache_events=True, + publisher="zmq", + endpoint="inproc://test-factory-true", + ) + publisher = EventPublisherFactory.create(config, DP_RANK) + assert isinstance(publisher, ZmqEventPublisher) + publisher.shutdown() + + # test unknown publisher + with pytest.raises(ValueError, match="Input should be"): + KVEventsConfig( + 
enable_kv_cache_events=True, + publisher="unknown_publisher", + endpoint="tcp://localhost:5557", + ) + + # test publisher not specified + config = KVEventsConfig( + enable_kv_cache_events=True, + # publisher not specified, should default to "zmq" + endpoint="tcp://localhost:5557", + ) + publisher = EventPublisherFactory.create(config, DP_RANK) + assert isinstance(publisher, ZmqEventPublisher) + publisher.shutdown() diff --git a/tests/distributed/test_expert_parallel.py b/tests/distributed/test_expert_parallel.py index 94f0ece4971b..0228d42a76a0 100644 --- a/tests/distributed/test_expert_parallel.py +++ b/tests/distributed/test_expert_parallel.py @@ -2,11 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Literal, NamedTuple, Optional +from typing import Literal, NamedTuple import pytest -from vllm.config import RunnerOption +from vllm.config.model import RunnerOption from vllm.logger import init_logger from ..utils import compare_two_settings, create_new_process_for_each_test @@ -22,9 +22,9 @@ class ParallelSetup(NamedTuple): class EPTestOptions(NamedTuple): trust_remote_code: bool - tokenizer_mode: Optional[str] - load_format: Optional[str] = None - hf_overrides: Optional[str] = None + tokenizer_mode: str | None + load_format: str | None = None + hf_overrides: str | None = None @dataclass @@ -40,9 +40,9 @@ def detailed( tp_base: int = 2, runner: RunnerOption = "auto", trust_remote_code: bool = False, - tokenizer_mode: Optional[str] = None, - load_format: Optional[str] = None, - hf_overrides: Optional[str] = None, + tokenizer_mode: str | None = None, + load_format: str | None = None, + hf_overrides: str | None = None, ): return EPTestSettings( parallel_setups=[ @@ -72,9 +72,9 @@ def fast( tp_base: int = 2, runner: RunnerOption = "auto", trust_remote_code: bool = False, - tokenizer_mode: Optional[str] = None, - load_format: Optional[str] = None, - hf_overrides: Optional[str] = None, + tokenizer_mode: str | None = None, + load_format: str | None = None, + hf_overrides: str | None = None, ): return EPTestSettings( parallel_setups=[ diff --git a/tests/distributed/test_expert_placement.py b/tests/distributed/test_expert_placement.py index cb9c8f507404..8b3a64b9c134 100644 --- a/tests/distributed/test_expert_placement.py +++ b/tests/distributed/test_expert_placement.py @@ -85,7 +85,7 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size): else: expected_test_local = base_experts - test_local_experts, test_expert_map = determine_expert_map( + test_local_experts, test_expert_map, _ = determine_expert_map( ep_size=test_ep_size, ep_rank=ep_rank, global_num_experts=test_global_experts, @@ -116,7 +116,7 @@ def test_expert_placement_edge_cases(expert_placement_strategy, world_size): """Test edge cases for round_robin expert placement.""" # Test case 1: ep_size = 1 (should return None for expert_map) - local_num_experts, expert_map = determine_expert_map( + local_num_experts, expert_map, _ = determine_expert_map( ep_size=1, ep_rank=0, global_num_experts=8, @@ -217,7 +217,7 @@ def test_determine_expert_map_comprehensive(): expected_local, expected_map_pattern, ) in test_cases: - local_num_experts, expert_map = determine_expert_map( + local_num_experts, expert_map, _ = determine_expert_map( ep_size=ep_size, ep_rank=ep_rank, global_num_experts=global_num_experts, diff --git a/tests/distributed/test_multi_node_assignment.py b/tests/distributed/test_multi_node_assignment.py index 
8d818edbb3bd..5d3f524f4d2f 100644 --- a/tests/distributed/test_multi_node_assignment.py +++ b/tests/distributed/test_multi_node_assignment.py @@ -18,8 +18,8 @@ from vllm import initialize_ray_cluster from vllm.config import ParallelConfig -from vllm.executor.ray_utils import _wait_until_pg_removed -from vllm.utils import get_ip +from vllm.utils.network_utils import get_ip +from vllm.v1.executor.ray_utils import _wait_until_pg_removed VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" diff --git a/tests/distributed/test_nccl_symm_mem_allreduce.py b/tests/distributed/test_nccl_symm_mem_allreduce.py index 40dcf7567c92..eeb74bdf5357 100644 --- a/tests/distributed/test_nccl_symm_mem_allreduce.py +++ b/tests/distributed/test_nccl_symm_mem_allreduce.py @@ -23,7 +23,7 @@ initialize_model_parallel, ) from vllm.platforms import current_platform -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables torch.manual_seed(42) random.seed(44) diff --git a/tests/distributed/test_node_count.py b/tests/distributed/test_node_count.py index b48c025aa1a2..34e10084095a 100644 --- a/tests/distributed/test_node_count.py +++ b/tests/distributed/test_node_count.py @@ -7,7 +7,7 @@ from vllm.distributed.parallel_state import _node_count from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import get_ip, get_open_port +from vllm.utils.network_utils import get_ip, get_open_port if __name__ == "__main__": dist.init_process_group(backend="gloo") diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 43f0c9dd1a85..0ab94d30858f 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -11,7 +11,7 @@ import json import os from dataclasses import dataclass -from typing import Literal, NamedTuple, Optional +from typing import Literal, NamedTuple import pytest @@ -35,7 +35,7 @@ class ParallelSetup(NamedTuple): class PPTestOptions(NamedTuple): multi_node_only: bool - load_format: Optional[str] = None + load_format: str | None = None @dataclass @@ -52,7 +52,7 @@ def detailed( pp_base: int = 2, multi_node_only: bool = False, runner: RunnerOption = "auto", - load_format: Optional[str] = None, + load_format: str | None = None, ): return PPTestSettings( parallel_setups=[ @@ -76,7 +76,7 @@ def fast( pp_base: int = 2, runner: RunnerOption = "auto", multi_node_only: bool = False, - load_format: Optional[str] = None, + load_format: str | None = None, ): return PPTestSettings( parallel_setups=[ @@ -244,7 +244,7 @@ def _compare_tp( tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides hf_config = get_config(model_id, trust_remote_code) - skip_tokenizer_init = model_info.skip_tokenizer_init + require_embed_inputs = model_info.require_embed_inputs max_num_seqs = model_info.max_num_seqs dtype = "float16" @@ -299,16 +299,20 @@ def _compare_tp( common_args.extend(["--load-format", load_format]) if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) - if skip_tokenizer_init: - common_args.append("--skip-tokenizer-init") + if require_embed_inputs: + common_args.extend( + [ + "--skip-tokenizer-init", + "--enable-prompt-embeds", + "--enable-mm-embeds", + ] + ) if max_num_seqs: common_args.extend(["--max-num-seqs", f"{max_num_seqs}"]) if distributed_backend == "ray": - # For V1, test Ray Compiled Graph for all the tests + # Test Ray Compiled Graph for all the tests pp_env = { - 
"VLLM_USE_RAY_COMPILED_DAG": "1", - "VLLM_USE_RAY_SPMD_WORKER": "1", "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", } # Temporary. Currently when zeromq + SPMD is used, it does not properly diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index 2c9f47464008..2f2b43cb4cc2 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -1,16 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - -from typing import TYPE_CHECKING - import pytest +from typing_extensions import LiteralString from ..utils import compare_two_settings, create_new_process_for_each_test -if TYPE_CHECKING: - from typing_extensions import LiteralString - @pytest.mark.parametrize( "PP_SIZE, MODEL_NAME", diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 4bab709fb589..c3085beeb356 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -18,7 +18,7 @@ graph_capture, init_distributed_environment, ) -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables def distributed_run(fn, world_size): diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index 2df88377345d..53d906bbc7bd 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing import random import pytest @@ -8,6 +9,7 @@ import torch import torch.distributed as dist +from vllm import _custom_ops as ops from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.parallel_state import get_tp_group, graph_capture from vllm.platforms import current_platform @@ -134,3 +136,88 @@ def test_custom_quick_allreduce( monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) + + +def qr_variable_input(rank, world_size): + """ + When the tensor parallelism is set to 4 or 8, frequent changes + in the input shape can cause QuickReduce to hang (this issue + has been observed with the gpt_oss model). + """ + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + qr_max_size = None # MB + _ptr = ops.init_custom_qr(rank, world_size, qr_max_size) + ranks = [] + for i in range(world_size): + ranks.append(i) + dist.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:29500", + rank=rank, + world_size=world_size, + ) + cpu_group = torch.distributed.new_group(ranks, backend="nccl") + + handle = ops.qr_get_handle(_ptr) + world_size = dist.get_world_size(group=cpu_group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=cpu_group) + ops.qr_open_handles(_ptr, handles) + + num = 1 + s1 = 1024 + while num < 50000: # 50000 is sufficient to identify issues. 
+ dtype = torch.float16 + if num % 2 == 0: + s2 = 1024 + inp1 = torch.zeros( + (s1, s2), dtype=dtype, device=torch.cuda.current_device() + ) + else: + s2 = 2048 + inp1 = torch.ones((s1, s2), dtype=dtype, device=torch.cuda.current_device()) + result = torch.empty_like(inp1) + # FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4 + ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True) + try: + if inp1[0, 0] == 0: + assert torch.all(result == 0) + else: + assert torch.all(result == world_size) + except AssertionError: + print("Assertion failed! Allreduce results are incorrect.") + raise + num += 1 + + +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="only test quick allreduce for rocm" +) +@pytest.mark.parametrize("tp_size", [4, 8]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + multiprocessing.set_start_method("spawn", force=True) + # 60s is enough + timeout = 60 + processes = [] + for rank in range(tp_size): + p = multiprocessing.Process(target=qr_variable_input, args=(rank, tp_size)) + p.start() + processes.append((rank, p)) + for rank, p in processes: + p.join(timeout=timeout) + if p.is_alive(): + for r, proc in processes: + if proc.is_alive(): + proc.terminate() + proc.join() + raise RuntimeError(f"QuickReduce hang detected after {timeout} seconds!") + + +if __name__ == "__main__": + test_custom_quick_allreduce_variable_input(tp_size=4, pipeline_parallel_size=1) diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py index baf75fd48c63..4444327f01da 100644 --- a/tests/distributed/test_same_node.py +++ b/tests/distributed/test_same_node.py @@ -3,11 +3,24 @@ import os +import torch import torch.distributed as dist from vllm.distributed.parallel_state import in_the_same_node_as from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import get_ip, get_open_port +from vllm.utils.network_utils import get_ip, get_open_port + + +def _run_test(pg): + test_result = all(in_the_same_node_as(pg, source_rank=0)) + + expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" + assert test_result == expected, f"Expected {expected}, got {test_result}" + if pg == dist.group.WORLD: + print("Same node test passed! when using torch distributed!") + else: + print("Same node test passed! when using StatelessProcessGroup!") + if __name__ == "__main__": dist.init_process_group(backend="gloo") @@ -25,11 +38,12 @@ stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) for pg in [dist.group.WORLD, stateless_pg]: - test_result = all(in_the_same_node_as(pg, source_rank=0)) - - expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" - assert test_result == expected, f"Expected {expected}, got {test_result}" - if pg == dist.group.WORLD: - print("Same node test passed! when using torch distributed!") + if os.environ.get("VLLM_TEST_WITH_DEFAULT_DEVICE_SET", "0") == "1": + default_devices = ["cpu"] + if torch.cuda.is_available(): + default_devices.append("cuda") + for device in default_devices: + torch.set_default_device(device) + _run_test(pg) else: - print("Same node test passed! 
when using StatelessProcessGroup!") + _run_test(pg) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 1defd9690241..94b2b51211a6 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -11,12 +11,14 @@ import json import os from dataclasses import dataclass -from typing import Literal, NamedTuple, Optional +from typing import Literal, NamedTuple import pytest -from vllm.config import RunnerOption +from vllm.config.compilation import CompilationMode +from vllm.config.model import RunnerOption from vllm.logger import init_logger +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..models.registry import HF_EXAMPLE_MODELS from ..utils import compare_two_settings, create_new_process_for_each_test @@ -36,7 +38,7 @@ class ParallelSetup(NamedTuple): class SPTestOptions(NamedTuple): multi_node_only: bool - load_format: Optional[str] = None + load_format: str | None = None @dataclass @@ -53,7 +55,7 @@ def detailed( pp_base: int = 1, multi_node_only: bool = False, runner: RunnerOption = "auto", - load_format: Optional[str] = None, + load_format: str | None = None, ): parallel_setups = [] for eager_mode_val in [False, True]: @@ -84,7 +86,7 @@ def fast( pp_base: int = 1, runner: RunnerOption = "auto", multi_node_only: bool = False, - load_format: Optional[str] = None, + load_format: str | None = None, ): parallel_setups = [] for eager_mode_val in [False, True]: @@ -115,7 +117,7 @@ def fp8_quant( pp_base: int = 1, runner: RunnerOption = "auto", multi_node_only: bool = False, - load_format: Optional[str] = None, + load_format: str | None = None, ): parallel_setups = [] for fusion_val in [False, True]: @@ -158,6 +160,7 @@ def _compare_sp( runner: RunnerOption, test_options: SPTestOptions, num_gpus_available: int, + use_inductor_graph_partition: bool, *, method: Literal["generate", "encode"], is_multimodal: bool, @@ -178,7 +181,7 @@ def _compare_sp( trust_remote_code = model_info.trust_remote_code tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides - skip_tokenizer_init = model_info.skip_tokenizer_init + require_embed_inputs = model_info.require_embed_inputs if load_format == "dummy": # Avoid OOM @@ -230,11 +233,17 @@ def _compare_sp( common_args.extend(["--load-format", load_format]) if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) - if skip_tokenizer_init: - common_args.append("--skip-tokenizer-init") + if require_embed_inputs: + common_args.extend( + [ + "--skip-tokenizer-init", + "--enable-prompt-embeds", + "--enable-mm-embeds", + ] + ) compilation_config = { - "level": 3, + "mode": CompilationMode.VLLM_COMPILE, "custom_ops": ["+rms_norm"], "compile_sizes": [4, 8], "pass_config": { @@ -242,6 +251,7 @@ def _compare_sp( "enable_fusion": enable_fusion, "enable_noop": True, }, + "use_inductor_graph_partition": use_inductor_graph_partition, } tp_sp_args = [ @@ -269,14 +279,14 @@ def _compare_sp( SP_TEXT_GENERATION_MODELS = { # [Decoder-only] - "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), + "hmellor/tiny-random-LlamaForCausalLM": SPTestSettings.fast(), "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(), } SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] - "meta-llama/Llama-3.2-1B-Instruct", + "hmellor/tiny-random-LlamaForCausalLM", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", ] @@ -296,6 +306,7 @@ def _compare_sp( if model_id in SP_TEST_MODELS ], ) 
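The parametrization added below also toggles use_inductor_graph_partition for the sequence-parallel runs. For reference, the compilation_config dict built above corresponds roughly to the config classes used elsewhere in this PR; an illustrative sketch with example values (not code from the diff):

from vllm.config import CompilationConfig, CompilationMode, PassConfig

# Example values only; enable_fusion and use_inductor_graph_partition vary
# per test case in the parametrized test.
example_compilation_config = CompilationConfig(
    mode=CompilationMode.VLLM_COMPILE,
    custom_ops=["+rms_norm"],
    compile_sizes=[4, 8],
    pass_config=PassConfig(enable_fusion=False, enable_noop=True),
    use_inductor_graph_partition=False,
)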
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False]) @create_new_process_for_each_test() def test_tp_sp_generation( model_id: str, @@ -304,7 +315,11 @@ def test_tp_sp_generation( runner: RunnerOption, test_options: SPTestOptions, num_gpus_available, + use_inductor_graph_partition: bool, ): + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") + _compare_sp( model_id, parallel_setup, @@ -312,6 +327,7 @@ def test_tp_sp_generation( runner, test_options, num_gpus_available, + use_inductor_graph_partition, method="generate", is_multimodal=False, ) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index cdea1bfe8f28..a7ace62e1b54 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -10,7 +10,8 @@ from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import get_open_port, update_environment_variables +from vllm.utils.network_utils import get_open_port +from vllm.utils.system_utils import update_environment_variables def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py index e669b81b04f0..b8f04cf8e62c 100644 --- a/tests/distributed/test_symm_mem_allreduce.py +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -23,7 +23,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.platforms import current_platform -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables torch.manual_seed(42) random.seed(44) diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 2a6936fcd4c2..8289f697fea6 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -10,11 +10,9 @@ import vllm.envs as envs from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import ( - cuda_device_count_stateless, - get_open_port, - update_environment_variables, -) +from vllm.utils.network_utils import get_open_port +from vllm.utils.system_utils import update_environment_variables +from vllm.utils.torch_utils import cuda_device_count_stateless from ..utils import multi_gpu_test diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 9d367349fc2e..bcee0eb3d6fa 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -5,7 +5,7 @@ from argparse import ArgumentError from contextlib import nullcontext from dataclasses import dataclass, field -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import pytest @@ -115,9 +115,9 @@ class NestedConfig: class DummyConfig: regular_bool: bool = True """Regular bool with default True""" - optional_bool: Optional[bool] = None + optional_bool: bool | None = None """Optional bool with default None""" - optional_literal: Optional[Literal["x", "y"]] = None + optional_literal: Literal["x", "y"] | None = None """Optional literal with default None""" tuple_n: tuple[int, ...] 
= field(default_factory=lambda: (1, 2, 3)) """Tuple with variable length""" @@ -127,8 +127,10 @@ class DummyConfig: """List with variable length""" list_literal: list[Literal[1, 2]] = field(default_factory=list) """List with literal choices""" - list_union: list[Union[str, type[object]]] = field(default_factory=list) + list_union: list[str | type[object]] = field(default_factory=list) """List with union type""" + set_n: set[int] = field(default_factory=lambda: {1, 2, 3}) + """Set with variable length""" literal_literal: Literal[Literal[1], Literal[2]] = 1 """Literal of literals with default 1""" json_tip: dict = field(default_factory=dict) @@ -152,11 +154,11 @@ def test_is_not_builtin(type_hint, expected): ("type_hint", "expected"), [ (Annotated[int, "annotation"], {int}), - (Optional[int], {int, type(None)}), - (Annotated[Optional[int], "annotation"], {int, type(None)}), - (Optional[Annotated[int, "annotation"]], {int, type(None)}), + (int | None, {int, type(None)}), + (Annotated[int | None, "annotation"], {int, type(None)}), + (Annotated[int, "annotation"] | None, {int, type(None)}), ], - ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"], + ids=["Annotated", "or_None", "Annotated_or_None", "or_None_Annotated"], ) def test_get_type_hints(type_hint, expected): assert get_type_hints(type_hint) == expected @@ -184,6 +186,9 @@ def test_get_kwargs(): # lists with unions should become str type. # If not, we cannot know which type to use for parsing assert kwargs["list_union"]["type"] is str + # sets should work like lists + assert kwargs["set_n"]["type"] is int + assert kwargs["set_n"]["nargs"] == "+" # literals of literals should have merged choices assert kwargs["literal_literal"]["choices"] == [1, 2] # dict should have json tip in help @@ -226,30 +231,30 @@ def test_compilation_config(): # set to O3 args = parser.parse_args(["-O0"]) - assert args.compilation_config.level == 0 + assert args.compilation_config.mode == 0 # set to O 3 (space) args = parser.parse_args(["-O", "1"]) - assert args.compilation_config.level == 1 + assert args.compilation_config.mode == 1 # set to O 3 (equals) args = parser.parse_args(["-O=2"]) - assert args.compilation_config.level == 2 + assert args.compilation_config.mode == 2 - # set to O.level 3 - args = parser.parse_args(["-O.level", "3"]) - assert args.compilation_config.level == 3 + # set to O.mode 3 + args = parser.parse_args(["-O.mode", "3"]) + assert args.compilation_config.mode == 3 # set to string form of a dict args = parser.parse_args( [ "-O", - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' '"use_inductor": false}', ] ) assert ( - args.compilation_config.level == 3 + args.compilation_config.mode == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] and not args.compilation_config.use_inductor ) @@ -258,12 +263,12 @@ def test_compilation_config(): args = parser.parse_args( [ "--compilation-config=" - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' '"use_inductor": true}', ] ) assert ( - args.compilation_config.level == 3 + args.compilation_config.mode == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] and args.compilation_config.use_inductor ) diff --git a/vllm/executor/__init__.py b/tests/entrypoints/anthropic/__init__.py similarity index 100% rename from vllm/executor/__init__.py rename to tests/entrypoints/anthropic/__init__.py diff --git 
a/tests/entrypoints/anthropic/test_messages.py b/tests/entrypoints/anthropic/test_messages.py new file mode 100644 index 000000000000..4e35554b4e33 --- /dev/null +++ b/tests/entrypoints/anthropic/test_messages.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import anthropic +import pytest +import pytest_asyncio + +from ...utils import RemoteAnthropicServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" + + +@pytest.fixture(scope="module") +def server(): # noqa: F811 + args = [ + "--max-model-len", + "2048", + "--enforce-eager", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--served-model-name", + "claude-3-7-sonnet-latest", + ] + + with RemoteAnthropicServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_simple_messages(client: anthropic.AsyncAnthropic): + resp = await client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=1024, + messages=[{"role": "user", "content": "how are you!"}], + ) + assert resp.stop_reason == "end_turn" + assert resp.role == "assistant" + + print(f"Anthropic response: {resp.model_dump_json()}") + + +@pytest.mark.asyncio +async def test_system_message(client: anthropic.AsyncAnthropic): + resp = await client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=1024, + system="you are a helpful assistant", + messages=[{"role": "user", "content": "how are you!"}], + ) + assert resp.stop_reason == "end_turn" + assert resp.role == "assistant" + + print(f"Anthropic response: {resp.model_dump_json()}") + + +@pytest.mark.asyncio +async def test_anthropic_streaming(client: anthropic.AsyncAnthropic): + resp = await client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=1024, + messages=[{"role": "user", "content": "how are you!"}], + stream=True, + ) + + async for chunk in resp: + print(chunk.model_dump_json()) + + +@pytest.mark.asyncio +async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic): + resp = await client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=1024, + messages=[ + {"role": "user", "content": "What's the weather like in New York today?"} + ], + tools=[ + { + "name": "get_current_weather", + "description": "Useful for querying the weather in a specified city.", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City or region, for example: " + "New York, London, Tokyo, etc.", + } + }, + "required": ["location"], + }, + } + ], + stream=False, + ) + assert resp.stop_reason == "tool_use" + assert resp.role == "assistant" + + print(f"Anthropic response: {resp.model_dump_json()}") + + @pytest.mark.asyncio + async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic): + resp = await client.messages.create( + model="claude-3-7-sonnet-latest", + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "What's the weather like in New York today?", + } + ], + tools=[ + { + "name": "get_current_weather", + "description": "Useful for querying the weather " + "in a specified city.", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City or region, for example: " + "New York, London, Tokyo, etc.", + } + }, + "required": ["location"], + }, + } + ], + stream=True, + ) + + 
async for chunk in resp: + print(chunk.model_dump_json()) diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index 937aa5c13246..747676ac9567 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm import LLM @@ -12,6 +13,8 @@ @pytest.mark.parametrize("backend", ["mp", "ray"]) @create_new_process_for_each_test() def test_collective_rpc(tp_size, backend, monkeypatch): + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") if tp_size == 1 and backend == "ray": pytest.skip("Skip duplicate test case") if tp_size == 1: @@ -24,7 +27,7 @@ def echo_rank(self): monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") llm = LLM( - model="meta-llama/Llama-3.2-1B-Instruct", + model="hmellor/tiny-random-LlamaForCausalLM", enforce_eager=True, load_format="dummy", tensor_parallel_size=tp_size, diff --git a/tests/entrypoints/llm/test_mm_cache_stats.py b/tests/entrypoints/llm/test_mm_cache_stats.py new file mode 100644 index 000000000000..e5ee99124409 --- /dev/null +++ b/tests/entrypoints/llm/test_mm_cache_stats.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging + +import pytest +import regex as re + +from vllm import LLM +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.v1.metrics import loggers as stat_loggers +from vllm.v1.metrics.reader import Counter, Metric + +from ..openai.test_vision import TEST_IMAGE_ASSETS + + +def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]: + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + }, + ], + } + ] + + +def _get_counter_value(metrics: list[Metric], name: str): + metric = next(m for m in metrics if m.name == name) + assert isinstance(metric, Counter) + return metric.value + + +def _get_mm_cache_stats(metrics: list[Metric]): + mm_cache_queries = _get_counter_value(metrics, "vllm:mm_cache_queries") + mm_cache_hits = _get_counter_value(metrics, "vllm:mm_cache_hits") + + return mm_cache_queries, mm_cache_hits + + +def _get_mm_cache_log(llm: LLM, caplog_vllm: pytest.LogCaptureFixture) -> float: + caplog_vllm.clear() + with caplog_vllm.at_level(logging.INFO, logger=stat_loggers.__name__): + llm.llm_engine.do_log_stats() + + assert len(caplog_vllm.records) == 1 + msg = caplog_vllm.records[0].getMessage() + + assert "MM cache hit rate" in msg + match = re.search(r"MM cache hit rate: ([0-9.]+)%", msg) + assert match is not None + return float(match.group(1)) + + +@pytest.mark.parametrize("image_urls", [TEST_IMAGE_ASSETS[:2]], indirect=True) +@pytest.mark.parametrize("mm_processor_cache_type", ["lru", "shm"]) +def test_mm_cache_stats( + num_gpus_available, + image_urls, + mm_processor_cache_type, + caplog_vllm, +): + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + max_model_len=4096, + max_num_seqs=5, + enforce_eager=True, + mm_processor_cache_type=mm_processor_cache_type, + disable_log_stats=False, + limit_mm_per_prompt={"image": 2}, + ) + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (1, 0) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(0.0) + + llm.chat(_make_messages(image_urls[1])) + assert 
_get_mm_cache_stats(llm.get_metrics()) == (2, 0) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(0.0) + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (3, 1) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(33.3) + + # NOTE: This only resets hit rate stats in CachingMetrics + # The raw queries and hits counts remain unaffected + llm.reset_mm_cache() + + llm.chat(_make_messages(image_urls[0])) + assert _get_mm_cache_stats(llm.get_metrics()) == (4, 1) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(0.0) + + llm.chat(_make_messages(image_urls[1])) + assert _get_mm_cache_stats(llm.get_metrics()) == (5, 1) + assert _get_mm_cache_log(llm, caplog_vllm) == pytest.approx(0.0) diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index 81126a4f16f9..c17486d962f3 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from vllm import LLM @@ -12,8 +13,22 @@ def test_empty_prompt(): llm.generate([""]) -@pytest.mark.skip_v1 def test_out_of_vocab_token(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) with pytest.raises(ValueError, match="out of vocabulary"): llm.generate({"prompt_token_ids": [999999]}) + + +def test_require_mm_embeds(): + llm = LLM( + model="llava-hf/llava-1.5-7b-hf", + enforce_eager=True, + enable_mm_embeds=False, + ) + with pytest.raises(ValueError, match="--enable-mm-embeds"): + llm.generate( + { + "prompt": "<image>", + "multi_modal_data": {"image": torch.empty(1, 1, 1)}, + } + ) diff --git a/tests/entrypoints/openai/test_async_tokenization.py b/tests/entrypoints/openai/test_async_tokenization.py index 5df859df42da..682420a83a44 100644 --- a/tests/entrypoints/openai/test_async_tokenization.py +++ b/tests/entrypoints/openai/test_async_tokenization.py @@ -3,7 +3,7 @@ import asyncio import random -from typing import Callable +from collections.abc import Callable import openai import pytest diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index a96f0134c2ff..a2d8993441fc 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -53,22 +53,35 @@ def base64_encoded_audio() -> dict[str, str]: } -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) -async def test_single_chat_session_audio( - client: openai.AsyncOpenAI, model_name: str, audio_url: str +def dummy_messages_from_audio_url( + audio_urls: str | list[str], + content_text: str = "What's happening in this audio?", ): - messages = [ + if isinstance(audio_urls, str): + audio_urls = [audio_urls] + + return [ { "role": "user", "content": [ - {"type": "audio_url", "audio_url": {"url": audio_url}}, - {"type": "text", "text": "What's happening in this audio?"}, + *( + {"type": "audio_url", "audio_url": {"url": audio_url}} + for audio_url in audio_urls + ), + {"type": "text", "text": content_text}, ], } ] + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) +async def test_single_chat_session_audio( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = dummy_messages_from_audio_url(audio_url) + # test single completion chat_completion = 
await client.chat.completions.create( model=model_name, @@ -138,20 +151,9 @@ async def test_single_chat_session_audio_base64encoded( audio_url: str, base64_encoded_audio: dict[str, str], ): - messages = [ - { - "role": "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" # noqa: E501 - }, - }, - {"type": "text", "text": "What's happening in this audio?"}, - ], - } - ] + messages = dummy_messages_from_audio_url( + f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" + ) # test single completion chat_completion = await client.chat.completions.create( @@ -252,15 +254,7 @@ async def test_single_chat_session_input_audio( async def test_chat_streaming_audio( client: openai.AsyncOpenAI, model_name: str, audio_url: str ): - messages = [ - { - "role": "user", - "content": [ - {"type": "audio_url", "audio_url": {"url": audio_url}}, - {"type": "text", "text": "What's happening in this audio?"}, - ], - } - ] + messages = dummy_messages_from_audio_url(audio_url) # test single completion chat_completion = await client.chat.completions.create( @@ -365,18 +359,7 @@ async def test_chat_streaming_input_audio( async def test_multi_audio_input( client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str] ): - messages = [ - { - "role": "user", - "content": [ - *( - {"type": "audio_url", "audio_url": {"url": audio_url}} - for audio_url in audio_urls - ), - {"type": "text", "text": "What's happening in this audio?"}, - ], - } - ] + messages = dummy_messages_from_audio_url(audio_urls) if len(audio_urls) > MAXIMUM_AUDIOS: with pytest.raises(openai.BadRequestError): # test multi-audio input diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 50ec87b4464f..e63a6f10cbc7 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -3,12 +3,15 @@ import asyncio from http import HTTPStatus +from unittest.mock import AsyncMock, Mock import openai import pytest import pytest_asyncio import requests +from fastapi import Request +from vllm.v1.engine.exceptions import EngineDeadError from vllm.version import __version__ as VLLM_VERSION from ...utils import RemoteOpenAIServer @@ -224,3 +227,24 @@ def make_long_completion_request(): response = requests.get(server.url_for("load")) assert response.status_code == HTTPStatus.OK assert response.json().get("server_load") == 0 + + +@pytest.mark.asyncio +async def test_health_check_engine_dead_error(): + # Import the health function directly to test it in isolation + from vllm.entrypoints.openai.api_server import health + + # Create a mock request that simulates what FastAPI would provide + mock_request = Mock(spec=Request) + mock_app_state = Mock() + mock_engine_client = AsyncMock() + mock_engine_client.check_health.side_effect = EngineDeadError() + mock_app_state.engine_client = mock_engine_client + mock_request.app.state = mock_app_state + + # Test the health function directly with our mocked request + # This simulates what would happen if the engine dies + response = await health(mock_request) + + # Assert that it returns 503 Service Unavailable + assert response.status_code == 503 diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index d110234d60ac..d25958f602b3 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -3,7 +3,6 @@ # imports for structured outputs tests import json -from typing import Optional import 
jsonschema import openai # use the official client for correctness check @@ -176,7 +175,7 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, model_name: st [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], ) async def test_prompt_logprobs_chat( - client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int] + client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: int | None ): params: dict = { "messages": [ @@ -369,7 +368,7 @@ async def test_chat_completion_stream_options( assert chunk.usage is None else: assert chunk.usage is None - final_chunk = await stream.__anext__() + final_chunk = await anext(stream) assert final_chunk.usage is not None assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 @@ -600,145 +599,6 @@ async def test_structured_outputs_choice_chat_logprobs( assert item.logprob >= -9999.0, f"Failed (top_logprobs={top_logprobs})" -@pytest.mark.asyncio -async def test_named_tool_use( - client: openai.AsyncOpenAI, - sample_json_schema, -): - messages = [ - {"role": "system", "content": "you are a helpful assistant"}, - { - "role": "user", - "content": ( - "Give an example JSON for an employee profile using the specified tool." - ), - }, - ] - tools = [ - { - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema, - }, - } - ] - tool_choice = {"type": "function", "function": {"name": "dummy_function_name"}} - - # non-streaming - - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tools=tools, - tool_choice=tool_choice, - ) - message = chat_completion.choices[0].message - assert len(message.content) == 0 - json_string = message.tool_calls[0].function.arguments - json1 = json.loads(json_string) - jsonschema.validate(instance=json1, schema=sample_json_schema) - - messages.append({"role": "assistant", "content": json_string}) - messages.append( - {"role": "user", "content": "Give me another one with a different name and age"} - ) - - # streaming - - stream = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tools=tools, - tool_choice=tool_choice, - stream=True, - ) - - output = [] - finish_reason_count = 0 - async for chunk in stream: - delta = chunk.choices[0].delta - if delta.role: - assert delta.role == "assistant" - assert delta.content is None or len(delta.content) == 0 - if delta.tool_calls: - output.append(delta.tool_calls[0].function.arguments) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - # finish reason should only return in last block - assert finish_reason_count == 1 - json2 = json.loads("".join(output)) - jsonschema.validate(instance=json2, schema=sample_json_schema) - assert json1["name"] != json2["name"] - assert json1["age"] != json2["age"] - - -@pytest.mark.asyncio -async def test_inconsistent_tool_choice_and_tools( - client: openai.AsyncOpenAI, sample_json_schema -): - messages = [ - {"role": "system", "content": "you are a helpful assistant"}, - { - "role": "user", - "content": f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}", - }, - ] - - with pytest.raises(openai.BadRequestError): - await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tool_choice={ - "type": "function", - "function": {"name": 
"dummy_function_name"}, - }, - ) - - with pytest.raises(openai.BadRequestError): - await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tools=[ - { - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema, - }, - } - ], - tool_choice={ - "type": "function", - "function": {"name": "nondefined_function_name"}, - }, - ) - with pytest.raises(openai.BadRequestError): - await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tools=[ - { - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema, - }, - } - ], - tool_choice={}, - ) - - @pytest.mark.asyncio async def test_response_format_json_object(client: openai.AsyncOpenAI): for _ in range(2): diff --git a/tests/entrypoints/openai/test_chat_echo.py b/tests/entrypoints/openai/test_chat_echo.py index a9c9c8e3dfe8..b3b8b700336d 100644 --- a/tests/entrypoints/openai/test_chat_echo.py +++ b/tests/entrypoints/openai/test_chat_echo.py @@ -7,12 +7,23 @@ import pytest import pytest_asyncio +from vllm.config import ModelConfig + from ...utils import RemoteOpenAIServer # # any model with a chat template should work here MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +def get_vocab_size(model_name): + config = ModelConfig( + model=model_name, + seed=0, + dtype="float16", + ) + return config.get_vocab_size() + + @pytest.fixture(scope="module") def server(): args = [ @@ -107,6 +118,7 @@ async def test_top_logprobs(client: openai.AsyncOpenAI): completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, + max_tokens=1, extra_body={ "top_logprobs": -1, "logprobs": "true", @@ -115,3 +127,6 @@ async def test_top_logprobs(client: openai.AsyncOpenAI): assert completion.choices[0].logprobs is not None assert completion.choices[0].logprobs.content is not None assert len(completion.choices[0].logprobs.content) > 0 + assert len( + completion.choices[0].logprobs.content[0].top_logprobs + ) == get_vocab_size(MODEL_NAME) diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index d1202a59752b..ee79ed59c410 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -114,7 +114,9 @@ def test_get_gen_prompt( trust_remote_code=model_info.trust_remote_code, revision=model_info.revision, hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, dtype=model_info.dtype, ) diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index e64f68cad7c8..6d8db361a57d 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import datetime -from typing import Union +import json +import jsonschema import openai # use the official client for correctness check import pytest import pytest_asyncio @@ -166,7 +167,7 @@ async def test_function_tool_use( client: 
openai.AsyncOpenAI, model_name: str, stream: bool, - tool_choice: Union[str, dict], + tool_choice: str | dict, enable_thinking: bool, ): if not stream: @@ -195,11 +196,19 @@ async def test_function_tool_use( ) output = [] + reasoning = [] async for chunk in output_stream: - if chunk.choices and chunk.choices[0].delta.tool_calls: - output.extend(chunk.choices[0].delta.tool_calls) + if chunk.choices: + if enable_thinking and getattr( + chunk.choices[0].delta, "reasoning_content", None + ): + reasoning.append(chunk.choices[0].delta.reasoning_content) + if chunk.choices[0].delta.tool_calls: + output.extend(chunk.choices[0].delta.tool_calls) assert len(output) > 0 + if enable_thinking: + assert len(reasoning) > 0 @pytest.fixture(scope="module") @@ -248,10 +257,10 @@ async def test_tool_id_kimi_k2( ) assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 - assert ( - chat_completion.choices[0].message.tool_calls[0].id - == "functions.get_current_weather:0" - ) + assert chat_completion.choices[0].message.tool_calls[0].id in [ + "functions.get_current_weather:0", + "functions.get_forecast:1", + ] else: # Streaming test output_stream = await k2_client.chat.completions.create( @@ -267,7 +276,10 @@ async def test_tool_id_kimi_k2( if chunk.choices and chunk.choices[0].delta.tool_calls: output.extend(chunk.choices[0].delta.tool_calls) for o in output: - assert o.id is None or o.id == "functions.get_current_weather:0" + assert o.id is None or o.id in [ + "functions.get_current_weather:0", + "functions.get_forecast:1", + ] @pytest.mark.asyncio @@ -331,3 +343,144 @@ async def test_no_args_tool_call( else: # No tool called — just print model's direct reply assert message.content is not None + + +@pytest.mark.asyncio +async def test_named_tool_use( + client: openai.AsyncOpenAI, + sample_json_schema, +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": ( + "Give an example JSON for an employee profile using the specified tool." 
+ ), + }, + ] + tools = [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] + tool_choice = {"type": "function", "function": {"name": "dummy_function_name"}} + + # non-streaming + + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=tools, + temperature=0.0, + tool_choice=tool_choice, + ) + message = chat_completion.choices[0].message + assert len(message.content) == 0 + json_string = message.tool_calls[0].function.arguments + json1 = json.loads(json_string) + jsonschema.validate(instance=json1, schema=sample_json_schema) + + messages.append({"role": "assistant", "content": json_string}) + messages.append( + {"role": "user", "content": "Give me another one with a different name and age"} + ) + + # streaming + + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=tools, + tool_choice=tool_choice, + temperature=0.0, + stream=True, + ) + + output = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + assert delta.content is None or len(delta.content) == 0 + if delta.tool_calls: + output.append(delta.tool_calls[0].function.arguments) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + json2 = json.loads("".join(output)) + jsonschema.validate(instance=json2, schema=sample_json_schema) + assert json1["name"] != json2["name"] + assert json1["age"] != json2["age"] + + +@pytest.mark.asyncio +async def test_inconsistent_tool_choice_and_tools( + client: openai.AsyncOpenAI, sample_json_schema +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON for an employee profile that " + f"fits this schema: {sample_json_schema}", + }, + ] + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tool_choice={ + "type": "function", + "function": {"name": "dummy_function_name"}, + }, + ) + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ], + tool_choice={ + "type": "function", + "function": {"name": "nondefined_function_name"}, + }, + ) + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ], + tool_choice={}, + ) diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 3ed98ffe0e39..0a057b1848ad 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -292,3 +292,16 @@ async def test_prompt_logprobs_raises_error( 
temperature=0.0, extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True}, ) + + +@pytest.mark.asyncio +async def test_empty_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, +) -> None: + await client_with_prompt_embeds.completions.create( + model=MODEL_NAME, + prompt="Hello", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": []}, + ) diff --git a/tests/entrypoints/openai/test_enable_force_include_usage.py b/tests/entrypoints/openai/test_enable_force_include_usage.py new file mode 100644 index 000000000000..3ddf2308eb1d --- /dev/null +++ b/tests/entrypoints/openai/test_enable_force_include_usage.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import openai +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + + +@pytest.fixture(scope="module") +def chat_server_with_force_include_usage(request): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "128", + "--enforce-eager", + "--max-num-seqs", + "1", + "--enable-force-include-usage", + "--port", + "55857", + "--gpu-memory-utilization", + "0.2", + ] + + with RemoteOpenAIServer("Qwen/Qwen3-0.6B", args, auto_port=False) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def chat_client_with_force_include_usage(chat_server_with_force_include_usage): + async with chat_server_with_force_include_usage.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_chat_with_enable_force_include_usage( + chat_client_with_force_include_usage: openai.AsyncOpenAI, +): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] + + stream = await chat_client_with_force_include_usage.chat.completions.create( + model="Qwen/Qwen3-0.6B", + messages=messages, + max_completion_tokens=10, + extra_body=dict(min_tokens=10), + temperature=0.0, + stream=True, + ) + last_completion_tokens = 0 + async for chunk in stream: + if not len(chunk.choices): + assert chunk.usage.prompt_tokens >= 0 + assert ( + last_completion_tokens == 0 + or chunk.usage.completion_tokens > last_completion_tokens + or ( + not chunk.choices + and chunk.usage.completion_tokens == last_completion_tokens + ) + ) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) + else: + assert chunk.usage is None + + +@pytest.fixture(scope="module") +def transcription_server_with_force_include_usage(): + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-num-seqs", + "1", + "--enforce-eager", + "--enable-force-include-usage", + "--gpu-memory-utilization", + "0.2", + ] + + with RemoteOpenAIServer("openai/whisper-large-v3-turbo", args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def transcription_client_with_force_include_usage( + transcription_server_with_force_include_usage, +): + async with ( + transcription_server_with_force_include_usage.get_async_client() as async_client + ): + yield async_client + + +@pytest.mark.asyncio +async def test_transcription_with_enable_force_include_usage( + transcription_client_with_force_include_usage, winning_call +): + res = ( + await transcription_client_with_force_include_usage.audio.transcriptions.create( + 
model="openai/whisper-large-v3-turbo", + file=winning_call, + language="en", + temperature=0.0, + stream=True, + timeout=30, + ) + ) + + async for chunk in res: + if not len(chunk.choices): + # final usage sent + usage = chunk.usage + assert isinstance(usage, dict) + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert usage["total_tokens"] > 0 + else: + assert not hasattr(chunk, "usage") diff --git a/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py b/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py new file mode 100644 index 000000000000..fbfae4f268d5 --- /dev/null +++ b/tests/entrypoints/openai/test_gptoss_structural_tags_integration.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration tests for GPT-OSS structural tags functionality (PR #25515).""" + +import json +from unittest.mock import Mock + +import pytest + +from vllm.entrypoints.openai.protocol import ( + StructuredOutputsParams, +) +from vllm.entrypoints.tool_server import ToolServer +from vllm.reasoning.gptoss_reasoning_parser import ( + GptOssReasoningParser, +) + + +class TestGptOssStructuralTagsIntegration: + """Integration tests for structural tags in GPT-OSS tool calls.""" + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock tokenizer.""" + tokenizer = Mock() + tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) + return tokenizer + + @pytest.fixture + def gptoss_parser(self, mock_tokenizer): + """Create a real GptOssReasoningParser instance.""" + return GptOssReasoningParser(mock_tokenizer) + + @pytest.fixture + def tool_server_with_python(self): + """Create a tool server with Python tool enabled.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python") + return tool_server + + @pytest.fixture + def tool_server_empty(self): + """Create a tool server with no tools.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(return_value=False) + return tool_server + + def test_end_to_end_no_tools(self, gptoss_parser): + """Test end-to-end flow when no tools are available.""" + # Test the parser directly + result = gptoss_parser.prepare_structured_tag(None, None) + parsed_result = json.loads(result) + + # Verify basic structure + assert parsed_result["type"] == "structural_tag" + assert parsed_result["format"]["type"] == "triggered_tags" + assert len(parsed_result["format"]["tags"]) == 1 + + # Verify only analysis channel is allowed + analysis_tag = parsed_result["format"]["tags"][0] + assert analysis_tag["begin"] == "<|channel|>analysis<|message|>" + assert analysis_tag["content"]["type"] == "any_text" + assert analysis_tag["end"] == "<|end|>" + + # Verify triggers + assert parsed_result["format"]["triggers"] == ["<|channel|>analysis"] + assert parsed_result["format"]["stop_after_first"] is False + + def test_end_to_end_with_python_tool(self, gptoss_parser, tool_server_with_python): + """Test end-to-end flow with Python tool enabled.""" + result = gptoss_parser.prepare_structured_tag(None, tool_server_with_python) + parsed_result = json.loads(result) + + # Should have analysis tag + 2 python tags + assert len(parsed_result["format"]["tags"]) == 3 + + # Verify all expected tags are present + tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]] + expected_begins = [ + "<|channel|>analysis<|message|>", + "<|channel|>commentary to=python", + "<|channel|>analysis 
to=python", + ] + + for expected in expected_begins: + assert expected in tag_begins + + # Verify triggers include commentary + assert "<|channel|>analysis" in parsed_result["format"]["triggers"] + assert "<|channel|>commentary to=" in parsed_result["format"]["triggers"] + + def test_structured_outputs_params_integration( + self, gptoss_parser, tool_server_with_python + ): + """Test integration with StructuredOutputsParams.""" + # Generate structural tag + structural_tag = gptoss_parser.prepare_structured_tag( + None, tool_server_with_python + ) + + # Create StructuredOutputsParams + params = StructuredOutputsParams(structural_tag=structural_tag) + + # Verify the tag is properly stored and accessible + assert params.structural_tag == structural_tag + + # Verify the tag is valid JSON + parsed_tag = json.loads(params.structural_tag) + assert parsed_tag["type"] == "structural_tag" + + @pytest.mark.parametrize( + "browser, python, container, expected_tags", + [ + # No tools + (False, False, False, 1), + # Single tool + (True, False, False, 3), + # Multiple tools + (True, True, False, 5), + # All tools + (True, True, True, 7), + ], + ) + def test_tool_server_interaction_flow( + self, gptoss_parser, browser, python, container, expected_tags + ): + """Test the complete tool server interaction flow.""" + + # Create a mock ToolServer + tool_server = Mock(spec=ToolServer) + + # Simulate tool availability based on parameters + tool_server.has_tool = Mock( + side_effect=lambda tool: { + "browser": browser, + "python": python, + "container": container, + }.get(tool, False) + ) + + # Run the parser and verify results + result = gptoss_parser.prepare_structured_tag(None, tool_server) + parsed_result = json.loads(result) + + # Validate number of tags + assert len(parsed_result["format"]["tags"]) == expected_tags + + # Verify tool-specific tags exist for enabled tools + tag_begins = [tag["begin"] for tag in parsed_result["format"]["tags"]] + for tool, enabled in { + "browser": browser, + "python": python, + "container": container, + }.items(): + if enabled: + assert f"<|channel|>commentary to={tool}" in tag_begins + assert f"<|channel|>analysis to={tool}" in tag_begins + + def test_original_tag_preservation(self, gptoss_parser, tool_server_with_python): + """Test that original tags are preserved when provided.""" + original_tag = '{"type": "custom_tag", "data": "preserved"}' + + result = gptoss_parser.prepare_structured_tag( + original_tag, tool_server_with_python + ) + + # Should return original tag unchanged + assert result == original_tag + + @pytest.mark.parametrize( + "tools", + [ + [], + ["browser"], + ["python"], + ["container"], + ["browser", "python"], + ["browser", "container"], + ["python", "container"], + ["browser", "python", "container"], + ], + ) + def test_json_validity_comprehensive(self, gptoss_parser, tools): + """Test JSON validity across all possible tool combinations.""" + + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(side_effect=lambda tool: tool in tools) + + result = gptoss_parser.prepare_structured_tag(None, tool_server) + + # Should be valid JSON + parsed_result = json.loads(result) + + # Should have correct structure + assert parsed_result["type"] == "structural_tag" + assert "format" in parsed_result + assert "tags" in parsed_result["format"] + assert "triggers" in parsed_result["format"] + + # Tag count should be: 1 (analysis) + 2 * len(tools) + expected_tag_count = 1 + (2 * len(tools)) + assert len(parsed_result["format"]["tags"]) == expected_tag_count + 
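The parametrizations above encode the tag-count rule spelled out in `test_json_validity_comprehensive`: one always-present analysis tag, plus a `commentary to=<tool>` tag and an `analysis to=<tool>` tag per enabled tool. A tiny illustrative helper (the name is hypothetical, not part of the patch) makes the expected counts easy to audit against the parametrized cases.

    def expected_tag_count(enabled_tools: list[str]) -> int:
        # 1 tag for the bare analysis channel, plus two per enabled tool:
        # "<|channel|>commentary to=<tool>" and "<|channel|>analysis to=<tool>".
        return 1 + 2 * len(enabled_tools)

    assert expected_tag_count([]) == 1                                  # no tools
    assert expected_tag_count(["python"]) == 3                          # single tool
    assert expected_tag_count(["browser", "python", "container"]) == 7  # all tools
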
+ def test_error_handling_invalid_tool_server(self, gptoss_parser): + """Test error handling with invalid tool server.""" + # Tool server that raises exceptions + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(side_effect=Exception("Tool server error")) + + # Should handle gracefully and still return a valid tag + with pytest.raises(Exception, match="Tool server error"): + gptoss_parser.prepare_structured_tag(None, tool_server) + + def test_concurrent_requests_isolation(self, gptoss_parser): + """Test that concurrent requests don't interfere with each other.""" + # Simulate concurrent requests with different tool servers + tool_server_1 = Mock(spec=ToolServer) + tool_server_1.has_tool = Mock(side_effect=lambda tool: tool == "python") + + tool_server_2 = Mock(spec=ToolServer) + tool_server_2.has_tool = Mock(side_effect=lambda tool: tool == "browser") + + # Generate tags concurrently + result_1 = gptoss_parser.prepare_structured_tag(None, tool_server_1) + result_2 = gptoss_parser.prepare_structured_tag(None, tool_server_2) + + # Parse results + parsed_1 = json.loads(result_1) + parsed_2 = json.loads(result_2) + + # Verify they have different tool configurations + tags_1 = [tag["begin"] for tag in parsed_1["format"]["tags"]] + tags_2 = [tag["begin"] for tag in parsed_2["format"]["tags"]] + + # Result 1 should have python tags + assert "<|channel|>commentary to=python" in tags_1 + assert "<|channel|>commentary to=browser" not in tags_1 + + # Result 2 should have browser tags + assert "<|channel|>commentary to=browser" in tags_2 + assert "<|channel|>commentary to=python" not in tags_2 + + def test_tag_format_consistency(self, gptoss_parser): + """Test that all generated tags follow consistent format.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock( + side_effect=lambda tool: tool in ["python", "browser"] + ) + + result = gptoss_parser.prepare_structured_tag(None, tool_server) + parsed_result = json.loads(result) + + # Verify all tags have required fields + for tag in parsed_result["format"]["tags"]: + assert "begin" in tag + assert "content" in tag + assert "end" in tag + assert tag["content"]["type"] == "any_text" + assert tag["end"] == "<|end|>" + + # Verify begin format + assert tag["begin"].startswith("<|channel|>") + + def test_trigger_configuration(self, gptoss_parser): + """Test trigger configuration for different tool setups.""" + # Test with no tools + result_no_tools = gptoss_parser.prepare_structured_tag(None, None) + parsed_no_tools = json.loads(result_no_tools) + assert parsed_no_tools["format"]["triggers"] == ["<|channel|>analysis"] + + # Test with tools + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(side_effect=lambda tool: tool == "python") + + result_with_tools = gptoss_parser.prepare_structured_tag(None, tool_server) + parsed_with_tools = json.loads(result_with_tools) + + expected_triggers = ["<|channel|>analysis", "<|channel|>commentary to="] + assert set(parsed_with_tools["format"]["triggers"]) == set(expected_triggers) diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py index 674e14e4f5c1..c74f805961bc 100644 --- a/tests/entrypoints/openai/test_lora_adapters.py +++ b/tests/entrypoints/openai/test_lora_adapters.py @@ -23,11 +23,6 @@ {"r": 1024}, "is greater than max_lora_rank", ), - ( - "test_bias", - {"bias": "all"}, - "Adapter bias cannot be used without bias_enabled", - ), ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( 
"test_modules_to_save", diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index 2a15848ba447..a85418d5b5f4 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -4,7 +4,6 @@ from contextlib import suppress from dataclasses import dataclass, field from http import HTTPStatus -from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -38,13 +37,13 @@ class MockModelConfig: trust_remote_code: bool = False tokenizer_mode: str = "auto" max_model_len: int = 100 - tokenizer_revision: Optional[str] = None + tokenizer_revision: str | None = None multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig) hf_config: MockHFConfig = field(default_factory=MockHFConfig) - logits_processor_pattern: Optional[str] = None - diff_sampling_param: Optional[dict] = None + logits_processor_pattern: str | None = None + diff_sampling_param: dict | None = None allowed_local_media_path: str = "" - allowed_media_domains: Optional[list[str]] = None + allowed_media_domains: list[str] | None = None encoder_config = None generation_config: str = "auto" skip_tokenizer_init: bool = False @@ -56,7 +55,7 @@ def get_diff_sampling_param(self): class MockLoRAResolver(LoRAResolver): async def resolve_lora( self, base_model_name: str, lora_name: str - ) -> Optional[LoRARequest]: + ) -> LoRARequest | None: if lora_name == "test-lora": return LoRARequest( lora_name="test-lora", @@ -113,15 +112,17 @@ async def mock_generate(*args, **kwargs): mock_engine.generate.reset_mock() mock_engine.add_lora.reset_mock() - mock_model_config = MockModelConfig() + mock_engine.model_config = MockModelConfig() + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() + models = OpenAIServingModels( engine_client=mock_engine, base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config, ) serving_completion = OpenAIServingCompletion( - mock_engine, mock_model_config, models, request_logger=None + mock_engine, models, request_logger=None ) serving_completion._process_inputs = AsyncMock( diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 6b00dde494d1..dbcec9d31fc9 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -18,10 +18,18 @@ from ...utils import RemoteOpenAIServer -MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +MODELS = { + "text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct", +} PREV_MINOR_VERSION = version._prev_minor_version() +@pytest.fixture(scope="module", params=list(MODELS.keys())) +def model_key(request): + yield request.param + + @pytest.fixture(scope="module") def default_server_args(): return [ @@ -45,11 +53,12 @@ def default_server_args(): f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", ], ) -def server(default_server_args, request): +def server(model_key, default_server_args, request): if request.param: default_server_args.append(request.param) - with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + model_name = MODELS[model_key] + with RemoteOpenAIServer(model_name, default_server_args) as remote_server: yield remote_server @@ -60,64 +69,70 @@ async def client(server): _PROMPT = "Hello my name is Robert and I love magic" -tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) -_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"] - -_NUM_REQUESTS 
= 10 -_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT) -_NUM_GENERATION_TOKENS_PER_REQUEST = 10 - -# {metric_family: [(suffix, expected_value)]} -EXPECTED_VALUES = { - "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)], - "vllm:time_per_output_token_seconds": [ - ("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1)) - ], - "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prompt_tokens": [ - ("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS), - ], - "vllm:request_generation_tokens": [ - ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS), - ], - "vllm:request_params_n": [("_count", _NUM_REQUESTS)], - "vllm:request_params_max_tokens": [ - ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS), - ], - "vllm:iteration_tokens_total": [ - ( - "_sum", - _NUM_REQUESTS - * (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST), - ), - ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ], - "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], - "vllm:generation_tokens": [ - ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST) - ], - "vllm:request_success": [("_total", _NUM_REQUESTS)], -} +_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + + +def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int): + num_prompt_tokens = len(prompt_ids) + + # {metric_family: [(suffix, expected_value)]} + return { + "vllm:time_to_first_token_seconds": [("_count", num_requests)], + "vllm:time_per_output_token_seconds": [ + ("_count", num_requests * (max_tokens - 1)) + ], + "vllm:e2e_request_latency_seconds": [("_count", num_requests)], + "vllm:request_queue_time_seconds": [("_count", num_requests)], + "vllm:request_inference_time_seconds": [("_count", num_requests)], + "vllm:request_prefill_time_seconds": [("_count", num_requests)], + "vllm:request_decode_time_seconds": [("_count", num_requests)], + "vllm:request_prompt_tokens": [ + ("_sum", num_requests * num_prompt_tokens), + ("_count", num_requests), + ], + "vllm:request_generation_tokens": [ + ("_sum", num_requests * max_tokens), + ("_count", num_requests), + ], + "vllm:request_params_n": [("_count", num_requests)], + "vllm:request_params_max_tokens": [ + ("_sum", num_requests * max_tokens), + ("_count", num_requests), + ], + "vllm:iteration_tokens_total": [ + ( + "_sum", + num_requests * (num_prompt_tokens + max_tokens), + ), + ("_count", num_requests * max_tokens), + ], + "vllm:prompt_tokens": [("_total", num_requests * num_prompt_tokens)], + "vllm:generation_tokens": [("_total", num_requests * max_tokens)], + "vllm:request_success": [("_total", num_requests)], + } @pytest.mark.asyncio async def test_metrics_counts( server: RemoteOpenAIServer, client: openai.AsyncClient, + model_key: str, ): - for _ in range(_NUM_REQUESTS): + if model_key == "multimodal": + pytest.skip("Unnecessary test") + + model_name = MODELS[model_key] + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompt_ids = 
tokenizer.encode(_PROMPT) + num_requests = 10 + max_tokens = 10 + + for _ in range(num_requests): # sending a request triggers the metrics to be logged. await client.completions.create( - model=MODEL_NAME, - prompt=_TOKENIZED_PROMPT, - max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST, + model=model_name, + prompt=prompt_ids, + max_tokens=max_tokens, ) response = requests.get(server.url_for("metrics")) @@ -125,8 +140,9 @@ async def test_metrics_counts( assert response.status_code == HTTPStatus.OK # Loop over all expected metric_families - for metric_family, suffix_values_list in EXPECTED_VALUES.items(): - if (metric_family not in EXPECTED_METRICS_V1) or ( + expected_values = _get_expected_values(num_requests, prompt_ids, max_tokens) + for metric_family, suffix_values_list in expected_values.items(): + if metric_family not in EXPECTED_METRICS_V1 or ( not server.show_hidden_metrics and metric_family in HIDDEN_DEPRECATED_METRICS ): @@ -217,6 +233,11 @@ async def test_metrics_counts( "vllm:request_decode_time_seconds_count", ] +EXPECTED_METRICS_MM = [ + "vllm:mm_cache_queries", + "vllm:mm_cache_hits", +] + HIDDEN_DEPRECATED_METRICS: list[str] = [ "vllm:gpu_cache_usage_perc", "vllm:gpu_prefix_cache_queries", @@ -231,19 +252,43 @@ async def test_metrics_counts( async def test_metrics_exist( server: RemoteOpenAIServer, client: openai.AsyncClient, + model_key: str, ): + model_name = MODELS[model_key] + # sending a request triggers the metrics to be logged. - await client.completions.create( - model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0, - ) + if model_key == "text": + await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0, + ) + else: + await client.chat.completions.create( + model=model_name, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": _IMAGE_URL}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + max_tokens=5, + temperature=0.0, + ) response = requests.get(server.url_for("metrics")) assert response.status_code == HTTPStatus.OK - for metric in EXPECTED_METRICS_V1: + expected_metrics = EXPECTED_METRICS_V1 + if model_key == "multimodal": + # NOTE: Don't use in-place assignment + expected_metrics = expected_metrics + EXPECTED_METRICS_MM + + for metric in expected_metrics: if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics: continue assert metric in response.text @@ -253,9 +298,14 @@ async def test_metrics_exist( async def test_abort_metrics_reset( server: RemoteOpenAIServer, client: openai.AsyncClient, + model_key: str, ): + model_name = MODELS[model_key] + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompt_ids = tokenizer.encode(_PROMPT) + running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( - server + server, ) # Expect no running requests or kvcache usage @@ -268,8 +318,8 @@ async def test_abort_metrics_reset( for _ in range(3): task = asyncio.create_task( client.completions.create( - model=MODEL_NAME, - prompt=_TOKENIZED_PROMPT, + model=model_name, + prompt=prompt_ids, max_tokens=100, # Long generation to give time to abort temperature=0.0, ) @@ -281,7 +331,7 @@ async def test_abort_metrics_reset( # Check that we have running requests running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( - server + server, ) # Expect running requests and kvcache usage diff --git a/tests/entrypoints/openai/test_prompt_validation.py 
b/tests/entrypoints/openai/test_prompt_validation.py index 3d0885414b24..cd5661e5739f 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io +from unittest.mock import Mock # imports for structured outputs tests import openai @@ -10,7 +11,8 @@ import regex as re import torch -from vllm.entrypoints.renderer import BaseRenderer +from vllm.config import ModelConfig +from vllm.entrypoints.renderer import CompletionRenderer from ...utils import RemoteOpenAIServer @@ -59,6 +61,10 @@ async def test_out_of_vocab_token_ids(): def test_load_prompt_embeds( dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int ): + model_config = Mock(spec=ModelConfig) + model_config.enable_prompt_embeds = True + renderer = CompletionRenderer(model_config, tokenizer=None) + # construct arbitrary tensors of various dtypes, layouts, and sizes. # We need to check against different layouts to make sure that if a user # uses sparse tensors to reduce the transmission size of prompt embeddings, @@ -83,7 +89,7 @@ def test_load_prompt_embeds( buffer.seek(0) encoded_tensor = pybase64.b64encode(buffer.getvalue()) - loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor) + loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor) assert len(loaded_prompt_embeds) == 1 loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] assert loaded_tensor.device.type == "cpu" @@ -91,3 +97,22 @@ def test_load_prompt_embeds( torch.testing.assert_close( loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True ) + + +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("seq_len", [2]) +@pytest.mark.parametrize("hidden_size", [2]) +def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int): + model_config = Mock(spec=ModelConfig) + model_config.enable_prompt_embeds = False + renderer = CompletionRenderer(model_config, tokenizer=None) + + tensor = torch.randn((seq_len, hidden_size), dtype=dtype) + + buffer = io.BytesIO() + torch.save(tensor, buffer) + buffer.seek(0) + encoded_tensor = pybase64.b64encode(buffer.getvalue()) + + with pytest.raises(ValueError, match="--enable-prompt-embeds"): + renderer.load_prompt_embeds(encoded_tensor) diff --git a/tests/entrypoints/openai/test_protocol.py b/tests/entrypoints/openai/test_protocol.py new file mode 100644 index 000000000000..e9b1cfb58b50 --- /dev/null +++ b/tests/entrypoints/openai/test_protocol.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from openai_harmony import ( + Message, +) + +from vllm.entrypoints.openai.protocol import serialize_message, serialize_messages + + +def test_serialize_message() -> None: + dict_value = {"a": 1, "b": "2"} + assert serialize_message(dict_value) == dict_value + + msg_value = { + "role": "assistant", + "name": None, + "content": [{"type": "text", "text": "Test 1"}], + "channel": "analysis", + } + msg = Message.from_dict(msg_value) + assert serialize_message(msg) == msg_value + + +def test_serialize_messages() -> None: + assert serialize_messages(None) is None + assert serialize_messages([]) is None + + dict_value = {"a": 3, "b": "4"} + msg_value = { + "role": "assistant", + "name": None, + "content": [{"type": "text", "text": "Test 2"}], + "channel": "analysis", + } + msg = Message.from_dict(msg_value) + assert serialize_messages([msg, 
dict_value]) == [msg_value, dict_value] diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 57d88f84d251..4251d06435c1 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -16,6 +16,22 @@ MODEL_NAME = "openai/gpt-oss-20b" +GET_WEATHER_SCHEMA = { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, + }, + "required": ["latitude", "longitude"], + "additionalProperties": False, + }, + "strict": True, +} + @pytest.fixture(scope="module") def server(): @@ -305,6 +321,54 @@ async def test_streaming_types(client: OpenAI, model_name: str): assert len(stack_of_event_types) == 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling_with_streaming_types(client: OpenAI, model_name: str): + # this links the "done" type with the "start" type + # so every "done" type should have a corresponding "start" type + # and every open block should be closed by the end of the stream + pairs_of_event_types = { + "response.completed": "response.created", + "response.output_item.done": "response.output_item.added", + "response.output_text.done": "response.output_text.delta", + "response.reasoning_text.done": "response.reasoning_text.delta", + "response.reasoning_part.done": "response.reasoning_part.added", + "response.function_call_arguments.done": "response.function_call_arguments.delta", # noqa + } + + tools = [GET_WEATHER_SCHEMA] + input_list = [ + { + "role": "user", + "content": "What's the weather like in Paris today?", + } + ] + stream_response = await client.responses.create( + model=model_name, + input=input_list, + tools=tools, + stream=True, + ) + + stack_of_event_types = [] + async for event in stream_response: + if event.type == "response.created": + stack_of_event_types.append(event.type) + elif event.type == "response.completed": + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] + stack_of_event_types.pop() + if event.type.endswith("added"): + stack_of_event_types.append(event.type) + elif event.type.endswith("delta"): + if stack_of_event_types[-1] == event.type: + continue + stack_of_event_types.append(event.type) + elif event.type.endswith("done"): + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] + stack_of_event_types.pop() + assert len(stack_of_event_types) == 0 + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("background", [True, False]) @@ -483,23 +547,7 @@ def call_function(name, args): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling(client: OpenAI, model_name: str): - tools = [ - { - "type": "function", - "name": "get_weather", - "description": "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": {"type": "number"}, - "longitude": {"type": "number"}, - }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - } - ] + tools = [GET_WEATHER_SCHEMA] response = await client.responses.create( model=model_name, @@ -565,21 +613,7 @@ async def test_function_calling_multi_turn(client: OpenAI, 
model_name: str): }, "strict": True, }, - { - "type": "function", - "name": "get_weather", - "description": "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": {"type": "number"}, - "longitude": {"type": "number"}, - }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - }, + GET_WEATHER_SCHEMA, ] response = await client.responses.create( @@ -643,23 +677,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling_required(client: OpenAI, model_name: str): - tools = [ - { - "type": "function", - "name": "get_weather", - "description": "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": {"type": "number"}, - "longitude": {"type": "number"}, - }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - } - ] + tools = [GET_WEATHER_SCHEMA] with pytest.raises(BadRequestError): await client.responses.create( @@ -689,23 +707,7 @@ async def test_system_message_with_tools(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling_full_history(client: OpenAI, model_name: str): - tools = [ - { - "type": "function", - "name": "get_weather", - "description": "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": {"type": "number"}, - "longitude": {"type": "number"}, - }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - } - ] + tools = [GET_WEATHER_SCHEMA] input_messages = [ {"role": "user", "content": "What's the weather like in Paris today?"} @@ -745,6 +747,74 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str): assert response_2.output_text is not None +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling_with_stream(client: OpenAI, model_name: str): + tools = [GET_WEATHER_SCHEMA] + input_list = [ + { + "role": "user", + "content": "What's the weather like in Paris today?", + } + ] + stream_response = await client.responses.create( + model=model_name, + input=input_list, + tools=tools, + stream=True, + ) + assert stream_response is not None + final_tool_calls = {} + final_tool_calls_named = {} + async for event in stream_response: + if event.type == "response.output_item.added": + if event.item.type != "function_call": + continue + final_tool_calls[event.output_index] = event.item + final_tool_calls_named[event.item.name] = event.item + elif event.type == "response.function_call_arguments.delta": + index = event.output_index + tool_call = final_tool_calls[index] + if tool_call: + tool_call.arguments += event.delta + final_tool_calls_named[tool_call.name] = tool_call + elif event.type == "response.function_call_arguments.done": + assert event.arguments == final_tool_calls_named[event.name].arguments + for tool_call in final_tool_calls.values(): + if ( + tool_call + and tool_call.type == "function_call" + and tool_call.name == "get_weather" + ): + args = json.loads(tool_call.arguments) + result = call_function(tool_call.name, args) + input_list += [tool_call] + break + assert result is not None + response = await 
client.responses.create( + model=model_name, + input=input_list + + [ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], + tools=tools, + stream=True, + ) + assert response is not None + async for event in response: + # check that no function call events in the stream + assert event.type != "response.function_call_arguments.delta" + assert event.type != "response.function_call_arguments.done" + # check that the response contains output text + if event.type == "response.completed": + assert len(event.response.output) > 0 + assert event.response.output_text is not None + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_output_messages_enabled(client: OpenAI, model_name: str, server): diff --git a/tests/entrypoints/openai/test_responses_function_call_parsing.py b/tests/entrypoints/openai/test_responses_function_call_parsing.py new file mode 100644 index 000000000000..3c5a11c867eb --- /dev/null +++ b/tests/entrypoints/openai/test_responses_function_call_parsing.py @@ -0,0 +1,330 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test function call parsing in ResponsesRequest.""" + +import json + +import pytest +from openai.types.responses import ResponseFunctionToolCall + +from vllm.entrypoints.openai.protocol import ResponsesRequest + + +def test_function_call_dict_converted_to_object(): + """Test that function_call dictionaries are correctly parsed into + ResponseFunctionToolCall objects.""" + # Create a request with function_call as dict + request_data = { + "model": "gpt-oss", + "input": [ + { + "type": "function_call", + "call_id": "fc_123", + "name": "get_weather", + "arguments": '{"location": "Boston", "unit": "celsius"}', + } + ], + } + + request = ResponsesRequest(**request_data) + + # Verify the input item is now a ResponseFunctionToolCall object + assert len(request.input) == 1 + assert isinstance(request.input[0], ResponseFunctionToolCall) + assert request.input[0].call_id == "fc_123" + assert request.input[0].name == "get_weather" + assert request.input[0].arguments == '{"location": "Boston", "unit": "celsius"}' + + +def test_direct_function_call_object_preservation(): + """Test that ResponseFunctionToolCall objects passed directly are preserved.""" + # Create a request with ResponseFunctionToolCall object + function_call = ResponseFunctionToolCall( + type="function_call", + call_id="fc_456", + name="get_stock_price", + arguments='{"symbol": "AAPL"}', + ) + + request_data = {"model": "gpt-oss", "input": [function_call]} + + request = ResponsesRequest(**request_data) + + # Verify the object is preserved + assert len(request.input) == 1 + assert request.input[0] is function_call + + +def test_mixed_input_types_with_function_calls(): + """Test parsing with mixed input types including function calls.""" + + request_data = { + "model": "gpt-oss", + "input": [ + # Valid Message type + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "What's the weather?"}], + }, + # Function call that should be parsed + { + "type": "function_call", + "call_id": "fc_789", + "name": "check_weather", + "arguments": '{"location": "NYC"}', + }, + # Another function call + { + "type": "function_call", + "call_id": "fc_790", + "name": "get_time", + "arguments": "{}", + }, + ], + } + + request = ResponsesRequest(**request_data) + + # Verify mixed types are handled correctly + assert len(request.input) == 3 + # First 
item should be validated as Message + assert request.input[0]["type"] == "message" + # Second item should be parsed to ResponseFunctionToolCall + assert isinstance(request.input[1], ResponseFunctionToolCall) + assert request.input[1].call_id == "fc_789" + assert request.input[1].name == "check_weather" + # Third item should also be parsed to ResponseFunctionToolCall + assert isinstance(request.input[2], ResponseFunctionToolCall) + assert request.input[2].call_id == "fc_790" + assert request.input[2].name == "get_time" + + +def test_function_call_with_complex_arguments(): + """Test parsing function calls with complex nested arguments.""" + complex_args = { + "query": "weather forecast", + "filters": { + "location": {"city": "San Francisco", "state": "CA"}, + "timeRange": {"start": "2024-01-01", "end": "2024-01-07"}, + "metrics": ["temperature", "humidity", "precipitation"], + }, + "options": {"format": "detailed", "includeAlerts": True}, + } + + request_data = { + "model": "gpt-oss", + "input": [ + { + "type": "function_call", + "call_id": "fc_complex", + "name": "advanced_weather_query", + "arguments": json.dumps(complex_args), + } + ], + } + + request = ResponsesRequest(**request_data) + + # Verify complex arguments are preserved correctly + assert len(request.input) == 1 + assert isinstance(request.input[0], ResponseFunctionToolCall) + assert request.input[0].call_id == "fc_complex" + assert request.input[0].name == "advanced_weather_query" + + # Parse the arguments back to verify they're intact + parsed_args = json.loads(request.input[0].arguments) + assert parsed_args == complex_args + + +def test_invalid_function_call_fallback(): + """Test that invalid function call dictionaries fall back gracefully.""" + # Missing required field 'call_id' + request_data = { + "model": "gpt-oss", + "input": [ + {"type": "function_call", "name": "incomplete_function", "arguments": "{}"} + ], + } + + # This should not raise an error during model creation + # The validator should keep the original dict and let Pydantic + # handle validation + with pytest.raises(ValueError): + # Pydantic should raise a validation error for the invalid structure + ResponsesRequest(**request_data) + + +def test_string_input_not_affected(): + """Test that string input is not affected by the validator.""" + request_data = {"model": "gpt-oss", "input": "This is a simple string input"} + + request = ResponsesRequest(**request_data) + + # Verify string input remains unchanged + assert request.input == "This is a simple string input" + + +def test_empty_list_input(): + """Test that empty list input is handled correctly.""" + request_data = {"model": "gpt-oss", "input": []} + + request = ResponsesRequest(**request_data) + + # Verify empty list is preserved + assert request.input == [] + + +def test_function_call_output_not_affected(): + """Test that FunctionCallOutput is not affected by the function_call parsing.""" + + # Test with FunctionCallOutput as dict (should not be parsed) + request_data = { + "model": "gpt-oss", + "input": [ + { + "type": "function_call_output", + "call_id": "fc_output_123", + "output": "The weather in Boston is 72°F and sunny.", + } + ], + } + + request = ResponsesRequest(**request_data) + + # FunctionCallOutput should remain as dict (not converted to an object) + assert len(request.input) == 1 + assert isinstance(request.input[0], dict) + assert request.input[0]["type"] == "function_call_output" + assert request.input[0]["call_id"] == "fc_output_123" + assert request.input[0]["output"] == "The weather 
in Boston is 72°F and sunny." + + +def test_mixed_function_call_and_output(): + """Test that function_call is parsed while function_call_output is preserved.""" + request_data = { + "model": "gpt-oss", + "input": [ + # This should be parsed to ResponseFunctionToolCall + { + "type": "function_call", + "call_id": "fc_call_456", + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + # This should remain as dict + { + "type": "function_call_output", + "call_id": "fc_call_456", + "output": "NYC weather is 68°F with light rain", + }, + ], + } + + request = ResponsesRequest(**request_data) + + assert len(request.input) == 2 + + # First item should be parsed to ResponseFunctionToolCall + assert isinstance(request.input[0], ResponseFunctionToolCall) + assert request.input[0].call_id == "fc_call_456" + assert request.input[0].name == "get_weather" + + # Second item should remain as dict (FunctionCallOutput) + assert isinstance(request.input[1], dict) + assert request.input[1]["type"] == "function_call_output" + assert request.input[1]["call_id"] == "fc_call_456" + assert request.input[1]["output"] == "NYC weather is 68°F with light rain" + + +def test_function_call_validation_failure_logs_debug(caplog): + """Test that validation failures are logged at debug level.""" + from unittest.mock import patch + + request_data = { + "model": "gpt-oss", + "input": [ + { + "type": "function_call", + "name": "incomplete_function", + "arguments": "{}", # Missing call_id + } + ], + } + + # Mock the logger to verify debug was called + with patch("vllm.entrypoints.openai.protocol.logger") as mock_logger: + with pytest.raises(ValueError): + ResponsesRequest(**request_data) + + # Verify debug was called with expected message + mock_logger.debug.assert_called_once() + call_args = mock_logger.debug.call_args[0][0] + assert "Failed to parse function_call" in call_args + + +def test_validator_handles_iterator_input(): + """Test that validator can handle ValidatorIterator input (Pydantic internal).""" + + # This test simulates when Pydantic passes a ValidatorIterator instead of a list + # This happened with complex nested structures containing reasoning + function_call + + # Create test data that would normally be a list + test_input_items = [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Test"}], + }, + { + "type": "reasoning", + "id": "rs_1", + "summary": [{"type": "summary_text", "text": "Test reasoning"}], + "content": [{"type": "reasoning_text", "text": "Test content"}], + }, + { + "type": "function_call", + "call_id": "call_1", + "name": "test_function", + "arguments": '{"test": "value"}', + "id": "fc_1", + }, + ] + + # Mock data where input is an iterator (simulates Pydantic ValidatorIterator) + mock_data = { + "model": "test-model", + "input": iter(test_input_items), # Iterator instead of list + } + + # This should NOT raise an error with the fixed validator + try: + request = ResponsesRequest(**mock_data) + + # Verify the validator processed the data correctly + assert len(request.input) == 3 + + # Verify function_call was converted to ResponseFunctionToolCall object + function_call_item = None + for item in request.input: + if isinstance(item, ResponseFunctionToolCall): + function_call_item = item + break + + assert function_call_item is not None + assert function_call_item.call_id == "call_1" + assert function_call_item.name == "test_function" + + except Exception as e: + pytest.fail(f"Validator should handle iterator input, but failed with: {e}") + + 
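# NOTE (illustrative sketch, not part of this patch): the tests above assume
# that ResponsesRequest runs a "before"-mode Pydantic validator over `input`
# which (a) materializes iterator input handed over by Pydantic internals and
# (b) promotes bare function_call dicts to ResponseFunctionToolCall while
# leaving every other item untouched. The standalone model below only mirrors
# that assumed behaviour; the class and validator names are hypothetical, and
# the real validator lives in vllm/entrypoints/openai/protocol.py.
from openai.types.responses import ResponseFunctionToolCall
from pydantic import BaseModel, field_validator


class _SketchResponsesRequest(BaseModel):
    model: str
    input: str | list = []

    @field_validator("input", mode="before")
    @classmethod
    def _parse_function_calls(cls, value):
        if isinstance(value, str):
            return value  # plain string prompts pass through unchanged
        items = list(value)  # materializes ValidatorIterator / generator input
        parsed = []
        for item in items:
            if isinstance(item, dict) and item.get("type") == "function_call":
                try:
                    item = ResponseFunctionToolCall(**item)
                except Exception:
                    # keep the original dict; normal field validation reports it
                    pass
            parsed.append(item)
        return parsed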
+def test_validator_handles_empty_iterator(): + """Test validator handles empty iterator gracefully.""" + mock_data = { + "model": "test-model", + "input": iter([]), # Empty iterator + } + + request = ResponsesRequest(**mock_data) + assert request.input == [] diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index d31dadf90679..2f678a0535cc 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -9,22 +9,28 @@ from vllm.entrypoints.openai.protocol import BatchRequestOutput -# ruff: noqa: E501 -INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} - -{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-4", "method": "POST", "url": "/bad_url", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-5", "method": "POST", "url": "/v1/chat/completions", "body": {"stream": "True", "model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" - -INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} -{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" +MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" -INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}} -{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are an unhelpful assistant."}} - -{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}} -{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}""" +# ruff: noqa: E501 +INPUT_BATCH = ( + '{{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are a helpful assistant."}},{{"role": "user", "content": "Hello 
world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "NonExistModel", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-4", "method": "POST", "url": "/bad_url", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-5", "method": "POST", "url": "/v1/chat/completions", "body": {{"stream": "True", "model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}' +).format(MODEL_NAME) + +INVALID_INPUT_BATCH = ( + '{{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are a helpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}\n' + '{{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {{"model": "{0}", "messages": [{{"role": "system", "content": "You are an unhelpful assistant."}},{{"role": "user", "content": "Hello world!"}}],"max_tokens": 1000}}}}' +).format(MODEL_NAME) + +INPUT_EMBEDDING_BATCH = ( + '{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}\n' + '{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are an unhelpful assistant."}}\n' + '{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "Hello world!"}}\n' + '{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}' +) INPUT_SCORE_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} {"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" @@ -33,6 +39,9 @@ {"custom_id": "request-2", "method": "POST", "url": "/v1/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}} {"custom_id": "request-2", "method": "POST", "url": "/v2/rerank", "body": {"model": "BAAI/bge-reranker-v2-m3", "query": "What is the capital of France?", "documents": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}""" +INPUT_REASONING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful 
assistant."},{"role": "user", "content": "Solve this math problem: 2+2=?"}]}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "What is the capital of France?"}]}}""" + def test_empty_file(): with ( @@ -77,7 +86,7 @@ def test_completions(): "-o", output_file.name, "--model", - "NousResearch/Meta-Llama-3-8B-Instruct", + MODEL_NAME, ], ) proc.communicate() @@ -110,7 +119,7 @@ def test_completions_invalid_input(): "-o", output_file.name, "--model", - "NousResearch/Meta-Llama-3-8B-Instruct", + MODEL_NAME, ], ) proc.communicate() @@ -182,3 +191,50 @@ def test_score(input_batch): line_dict = json.loads(line) assert isinstance(line_dict, dict) assert line_dict["error"] is None + + +def test_reasoning_parser(): + """ + Test that reasoning_parser parameter works correctly in run_batch. + """ + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): + input_file.write(INPUT_REASONING_BATCH) + input_file.flush() + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "Qwen/Qwen3-0.6B", + "--reasoning-parser", + "qwen3", + ], + ) + proc.communicate() + proc.wait() + assert proc.returncode == 0, f"{proc=}" + + contents = output_file.read() + for line in contents.strip().split("\n"): + # Ensure that the output format conforms to the openai api. + # Validation should throw if the schema is wrong. + BatchRequestOutput.model_validate_json(line) + + # Ensure that there is no error in the response. + line_dict = json.loads(line) + assert isinstance(line_dict, dict) + assert line_dict["error"] is None + + # Check that reasoning_content is present and not empty + reasoning_content = line_dict["response"]["body"]["choices"][0]["message"][ + "reasoning_content" + ] + assert reasoning_content is not None + assert len(reasoning_content) > 0 diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index abe5a5f4ffc1..d1367b4eeaf6 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,16 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from __future__ import annotations - import asyncio from contextlib import suppress from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest import pytest_asyncio +from openai import OpenAI from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.protocol import ChatCompletionRequest @@ -21,9 +19,6 @@ from ...utils import RemoteOpenAIServer -if TYPE_CHECKING: - from openai import OpenAI - GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b" @@ -207,6 +202,132 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: ) +@pytest.mark.asyncio +async def test_gpt_oss_tool_message_array_content( + gptoss_client: OpenAI, with_tool_parser: bool +): + """Test that tool messages support both string and array content formats.""" + if not with_tool_parser: + pytest.skip("skip non-tool for array content tests") + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": 
"object", + "properties": { + "city": {"type": "string"}, + "state": {"type": "string"}, + }, + "required": ["city", "state"], + }, + }, + } + ] + + # Test 1: Tool message with string content + messages_string = [ + {"role": "user", "content": "What's the weather in Paris?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris", "state": "TX"}', + }, + } + ], + }, + {"role": "tool", "content": "The weather in Paris, TX is sunny, 22°C"}, + ] + + response_string = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages_string, + tools=tools, + temperature=0.0, + ) + + assert response_string is not None + assert response_string.choices[0].message is not None + + # Test 2: Tool message with array content + messages_array = [ + {"role": "user", "content": "What's the weather in Dallas?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_456", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Dallas", "state": "TX"}', + }, + } + ], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"} + ], + }, + ] + + response_array = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages_array, + tools=tools, + temperature=0.0, + ) + + assert response_array is not None + assert response_array.choices[0].message is not None + + # Test 3: Tool message with multiple array content items + messages_multi_array = [ + {"role": "user", "content": "Search for information"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_789", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Austin", "state": "TX"}', + }, + } + ], + }, + { + "role": "tool", + "content": [ + {"type": "text", "text": "Weather data: "}, + {"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"}, + {"type": "text", "text": " with 60% humidity"}, + ], + }, + ] + + response_multi_array = await gptoss_client.chat.completions.create( + model=GPT_OSS_MODEL_NAME, + messages=messages_multi_array, + tools=tools, + temperature=0.0, + ) + + assert response_multi_array is not None + assert response_multi_array.choices[0].message is not None + + MODEL_NAME = "openai-community/gpt2" MODEL_NAME_SHORT = "gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" @@ -245,17 +366,13 @@ def get_diff_sampling_param(self): return self.diff_sampling_param or {} -def _build_serving_chat( - engine: AsyncLLM, model_config: MockModelConfig -) -> OpenAIServingChat: +def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: models = OpenAIServingModels( engine_client=engine, base_model_paths=BASE_MODEL_PATHS, - model_config=model_config, ) serving_chat = OpenAIServingChat( engine, - model_config, models, response_role="assistant", chat_template=CHAT_TEMPLATE, @@ -280,18 +397,17 @@ async def _fake_process_inputs( @dataclass class MockEngine: - async def get_model_config(self): - return MockModelConfig() + model_config: MockModelConfig = field(default_factory=MockModelConfig) + processor: MagicMock = field(default_factory=MagicMock) + io_processor: MagicMock = field(default_factory=MagicMock) async def _async_serving_chat_init(): engine = MockEngine() - model_config = await engine.get_model_config() - models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS) + models = OpenAIServingModels(engine, BASE_MODEL_PATHS) 
serving_completion = OpenAIServingChat( engine, - model_config, models, response_role="assistant", chat_template=CHAT_TEMPLATE, @@ -311,8 +427,11 @@ async def test_serving_chat_returns_correct_model_name(): mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() - serving_chat = _build_serving_chat(mock_engine, MockModelConfig()) + serving_chat = _build_serving_chat(mock_engine) messages = [{"role": "user", "content": "what is 1+1?"}] async def return_model_name(*args): @@ -338,8 +457,11 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() - serving_chat = _build_serving_chat(mock_engine, MockModelConfig()) + serving_chat = _build_serving_chat(mock_engine) req = ChatCompletionRequest( model=MODEL_NAME, @@ -368,9 +490,12 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = mock_model_config + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() # Initialize the serving chat - serving_chat = _build_serving_chat(mock_engine, mock_model_config) + serving_chat = _build_serving_chat(mock_engine) # Test Case 1: No max_tokens specified in request req = ChatCompletionRequest( @@ -410,9 +535,12 @@ async def test_serving_chat_should_set_correct_max_tokens(): mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = mock_model_config + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() # Initialize the serving chat - serving_chat = _build_serving_chat(mock_engine, mock_model_config) + serving_chat = _build_serving_chat(mock_engine) # Test case 1: No max_tokens specified, defaults to context_window req = ChatCompletionRequest( @@ -453,9 +581,12 @@ async def test_serving_chat_could_load_correct_generation_config(): mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = mock_model_config + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() # Initialize the serving chat - serving_chat = _build_serving_chat(mock_engine, mock_model_config) + serving_chat = _build_serving_chat(mock_engine) req = ChatCompletionRequest( model=MODEL_NAME, @@ -496,8 +627,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): mock_engine = MagicMock(spec=AsyncLLM) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + mock_engine.model_config = mock_model_config + mock_engine.processor = MagicMock() + mock_engine.io_processor = MagicMock() - serving_chat = _build_serving_chat(mock_engine, mock_model_config) + serving_chat = _build_serving_chat(mock_engine) # Test cache_salt req = ChatCompletionRequest( diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 0c52270c13af..46d8871441a7 100644 
--- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -22,10 +22,12 @@ def serving() -> OpenAIServing: model_config = Mock(spec=ModelConfig) model_config.max_model_len = 32768 models = Mock(spec=OpenAIServingModels) + models.model_config = model_config + models.processor = Mock() + models.io_processor = Mock() serving = OpenAIServing( engine_client=engine_client, - model_config=model_config, models=models, request_logger=None, ) diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index ed9dedcc6f08..3c022870dba4 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -16,7 +16,7 @@ from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully." LORA_UNLOADING_SUCCESS_MESSAGE = ( @@ -25,15 +25,17 @@ async def _async_serving_models_init() -> OpenAIServingModels: - mock_model_config = MagicMock(spec=ModelConfig) mock_engine_client = MagicMock(spec=EngineClient) # Set the max_model_len attribute to avoid missing attribute + mock_model_config = MagicMock(spec=ModelConfig) mock_model_config.max_model_len = 2048 + mock_engine_client.model_config = mock_model_config + mock_engine_client.processor = MagicMock() + mock_engine_client.io_processor = MagicMock() serving_models = OpenAIServingModels( engine_client=mock_engine_client, base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config, lora_modules=None, ) await serving_models.init_static_loras() diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index cd7bb06ad320..263b076db183 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import AsyncExitStack -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest import pytest_asyncio @@ -70,11 +70,14 @@ async def serving_responses_instance(self): """Create a real OpenAIServingResponses instance for testing""" # Create minimal mocks for required dependencies engine_client = MagicMock() - engine_client.get_model_config = AsyncMock() model_config = MagicMock() model_config.hf_config.model_type = "test" model_config.get_diff_sampling_param.return_value = {} + engine_client.model_config = model_config + + engine_client.processor = MagicMock() + engine_client.io_processor = MagicMock() models = MagicMock() @@ -83,7 +86,6 @@ async def serving_responses_instance(self): # Create the actual instance instance = OpenAIServingResponses( engine_client=engine_client, - model_config=model_config, models=models, request_logger=None, chat_template=None, @@ -132,18 +134,20 @@ async def serving_responses_instance(self): """Create a real OpenAIServingResponses instance for testing""" # Create minimal mocks for required dependencies engine_client = MagicMock() - engine_client.get_model_config = AsyncMock() model_config = MagicMock() model_config.hf_config.model_type = "test" model_config.get_diff_sampling_param.return_value = {} + 
engine_client.model_config = model_config + + engine_client.processor = MagicMock() + engine_client.io_processor = MagicMock() models = MagicMock() # Create the actual instance instance = OpenAIServingResponses( engine_client=engine_client, - model_config=model_config, models=models, request_logger=None, chat_template=None, diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index ff46df81d0ff..d75119cb7b43 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -1,37 +1,93 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import signal +import subprocess +import sys +import time + import openai import pytest -from ...utils import RemoteOpenAIServer +from vllm.utils.network_utils import get_open_port -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" @pytest.mark.asyncio async def test_shutdown_on_engine_failure(): - # dtype, max-len etc set so that this can run in CI - args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "8192", - "--enforce-eager", - "--max-num-seqs", - "128", - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - async with remote_server.get_async_client() as client: - with pytest.raises((openai.APIConnectionError, openai.InternalServerError)): - # Asking for lots of prompt logprobs will currently crash the - # engine. This may change in the future when that bug is fixed - prompt = "Hello " * 4000 - await client.completions.create( - model=MODEL_NAME, prompt=prompt, extra_body={"prompt_logprobs": 10} + """Verify that API returns connection error when server process is killed. + + Starts a vLLM server, kills it to simulate a crash, then verifies that + subsequent API calls fail appropriately. + """ + + port = get_open_port() + + proc = subprocess.Popen( + [ + # dtype, max-len etc set so that this can run in CI + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + MODEL_NAME, + "--dtype", + "bfloat16", + "--max-model-len", + "128", + "--enforce-eager", + "--port", + str(port), + "--gpu-memory-utilization", + "0.05", + "--max-num-seqs", + "2", + "--disable-frontend-multiprocessing", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN), + ) + + # Wait for server startup + start_time = time.time() + client = openai.AsyncOpenAI( + base_url=f"http://localhost:{port}/v1", + api_key="dummy", + max_retries=0, + timeout=10, + ) + + # Poll until server is ready + while time.time() - start_time < 30: + try: + await client.completions.create( + model=MODEL_NAME, prompt="Hello", max_tokens=1 + ) + break + except Exception: + time.sleep(0.5) + if proc.poll() is not None: + stdout, stderr = proc.communicate(timeout=1) + pytest.fail( + f"Server died during startup. 
stdout: {stdout}, stderr: {stderr}" ) + else: + proc.terminate() + proc.wait(timeout=5) + pytest.fail("Server failed to start in 30 seconds") + + # Kill server to simulate crash + proc.terminate() + time.sleep(1) + + # Verify API calls now fail + with pytest.raises((openai.APIConnectionError, openai.APIStatusError)): + await client.completions.create( + model=MODEL_NAME, prompt="This should fail", max_tokens=1 + ) - # Now the server should shut down - return_code = remote_server.proc.wait(timeout=8) - assert return_code is not None + return_code = proc.wait(timeout=5) + assert return_code is not None diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index 4c7d1c14ca17..7ecdac518f97 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -55,22 +55,35 @@ def base64_encoded_video() -> dict[str, str]: } -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_single_chat_session_video( - client: openai.AsyncOpenAI, model_name: str, video_url: str +def dummy_messages_from_video_url( + video_urls: str | list[str], + content_text: str = "What's in this video?", ): - messages = [ + if isinstance(video_urls, str): + video_urls = [video_urls] + + return [ { "role": "user", "content": [ - {"type": "video_url", "video_url": {"url": video_url}}, - {"type": "text", "text": "What's in this video?"}, + *( + {"type": "video_url", "video_url": {"url": video_url}} + for video_url in video_urls + ), + {"type": "text", "text": content_text}, ], } ] + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +async def test_single_chat_session_video( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = dummy_messages_from_video_url(video_url) + # test single completion chat_completion = await client.chat.completions.create( model=model_name, @@ -137,15 +150,7 @@ async def test_error_on_invalid_video_url_type( async def test_single_chat_session_video_beamsearch( client: openai.AsyncOpenAI, model_name: str, video_url: str ): - messages = [ - { - "role": "user", - "content": [ - {"type": "video_url", "video_url": {"url": video_url}}, - {"type": "text", "text": "What's in this video?"}, - ], - } - ] + messages = dummy_messages_from_video_url(video_url) chat_completion = await client.chat.completions.create( model=model_name, @@ -172,20 +177,9 @@ async def test_single_chat_session_video_base64encoded( video_url: str, base64_encoded_video: dict[str, str], ): - messages = [ - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" # noqa: E501 - }, - }, - {"type": "text", "text": "What's in this video?"}, - ], - } - ] + messages = dummy_messages_from_video_url( + f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" + ) # test single completion chat_completion = await client.chat.completions.create( @@ -231,20 +225,10 @@ async def test_single_chat_session_video_base64encoded_beamsearch( video_url: str, base64_encoded_video: dict[str, str], ): - messages = [ - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" # noqa: E501 - }, - }, - {"type": "text", "text": "What's in this video?"}, - ], - } - ] + messages = dummy_messages_from_video_url( + 
f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" + ) + chat_completion = await client.chat.completions.create( model=model_name, messages=messages, @@ -265,15 +249,7 @@ async def test_single_chat_session_video_base64encoded_beamsearch( async def test_chat_streaming_video( client: openai.AsyncOpenAI, model_name: str, video_url: str ): - messages = [ - { - "role": "user", - "content": [ - {"type": "video_url", "video_url": {"url": video_url}}, - {"type": "text", "text": "What's in this video?"}, - ], - } - ] + messages = dummy_messages_from_video_url(video_url) # test single completion chat_completion = await client.chat.completions.create( @@ -318,18 +294,7 @@ async def test_chat_streaming_video( async def test_multi_video_input( client: openai.AsyncOpenAI, model_name: str, video_urls: list[str] ): - messages = [ - { - "role": "user", - "content": [ - *( - {"type": "video_url", "video_url": {"url": video_url}} - for video_url in video_urls - ), - {"type": "text", "text": "What's in this video?"}, - ], - } - ] + messages = dummy_messages_from_video_url(video_urls) if len(video_urls) > MAXIMUM_VIDEOS: with pytest.raises(openai.BadRequestError): # test multi-video input diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 5a15a352f45c..2a7df08ea3b0 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -34,7 +34,7 @@ ], [ "The image shows a Venn diagram with three over", - "This image shows a Venn diagram with three over", + "The image shows a colorful Venn diagram with", ], [ "This image displays a gradient of colors ranging from", @@ -78,6 +78,27 @@ def base64_encoded_image(local_asset_server) -> dict[str, str]: } +def dummy_messages_from_image_url( + image_urls: str | list[str], + content_text: str = "What's in this image?", +): + if isinstance(image_urls, str): + image_urls = [image_urls] + + return [ + { + "role": "user", + "content": [ + *( + {"type": "image_url", "image_url": {"url": image_url}} + for image_url in image_urls + ), + {"type": "text", "text": content_text}, + ], + } + ] + + def get_hf_prompt_tokens(model_name, content, image_url): processor = AutoProcessor.from_pretrained( model_name, trust_remote_code=True, num_crops=4 @@ -107,15 +128,7 @@ async def test_single_chat_session_image( client: openai.AsyncOpenAI, model_name: str, image_url: str ): content_text = "What's in this image?" - messages = [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": content_text}, - ], - } - ] + messages = dummy_messages_from_image_url(image_url, content_text) max_completion_tokens = 10 # test single completion @@ -188,15 +201,8 @@ async def test_error_on_invalid_image_url_type( async def test_single_chat_session_image_beamsearch( client: openai.AsyncOpenAI, model_name: str, image_url: str ): - messages = [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": "What's in this image?"}, - ], - } - ] + content_text = "What's in this image?" + messages = dummy_messages_from_image_url(image_url, content_text) chat_completion = await client.chat.completions.create( model=model_name, @@ -226,20 +232,10 @@ async def test_single_chat_session_image_base64encoded( base64_encoded_image: dict[str, str], ): content_text = "What's in this image?" 
- messages = [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" # noqa: E501 - }, - }, - {"type": "text", "text": content_text}, - ], - } - ] + messages = dummy_messages_from_image_url( + f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}", + content_text, + ) max_completion_tokens = 10 # test single completion @@ -293,20 +289,10 @@ async def test_single_chat_session_image_base64encoded_beamsearch( raw_image_url = TEST_IMAGE_ASSETS[image_idx] expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx] - messages = [ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" # noqa: E501 - }, - }, - {"type": "text", "text": "What's in this image?"}, - ], - } - ] + messages = dummy_messages_from_image_url( + f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" + ) + chat_completion = await client.chat.completions.create( model=model_name, messages=messages, @@ -326,15 +312,7 @@ async def test_single_chat_session_image_base64encoded_beamsearch( async def test_chat_streaming_image( client: openai.AsyncOpenAI, model_name: str, image_url: str ): - messages = [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": "What's in this image?"}, - ], - } - ] + messages = dummy_messages_from_image_url(image_url) # test single completion chat_completion = await client.chat.completions.create( @@ -381,18 +359,7 @@ async def test_chat_streaming_image( async def test_multi_image_input( client: openai.AsyncOpenAI, model_name: str, image_urls: list[str] ): - messages = [ - { - "role": "user", - "content": [ - *( - {"type": "image_url", "image_url": {"url": image_url}} - for image_url in image_urls - ), - {"type": "text", "text": "What's in this image?"}, - ], - } - ] + messages = dummy_messages_from_image_url(image_urls) if len(image_urls) > MAXIMUM_IMAGES: with pytest.raises(openai.BadRequestError): # test multi-image input diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_vision_embeds.py similarity index 76% rename from tests/entrypoints/openai/test_skip_tokenizer.py rename to tests/entrypoints/openai/test_vision_embeds.py index 6998566c03d0..a6593c5b05e2 100644 --- a/tests/entrypoints/openai/test_skip_tokenizer.py +++ b/tests/entrypoints/openai/test_vision_embeds.py @@ -15,30 +15,7 @@ DTYPE = "float16" -@pytest.fixture(scope="module") -def server(): - args = [ - "--runner", - "pooling", - # use half precision for speed and memory savings in CI environment - "--dtype", - DTYPE, - "--enforce-eager", - "--trust-remote-code", - "--skip-tokenizer-init", - "--max-num-seqs", - "32", - "--model-impl", - "terratorch", - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_request(server: RemoteOpenAIServer, model_name: str): +def _terratorch_dummy_inputs(model_name: str): pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) @@ -54,7 +31,7 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str): binary_data = buffer_coord.read() base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8") - prompt = { + return { "model": model_name, "additional_data": 
{"prompt_token_ids": [1]}, "encoding_format": "base64", @@ -74,12 +51,33 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str): ], } - # test single pooling - response = requests.post(server.url_for("pooling"), json=prompt) - response.raise_for_status() - output = response.json()["data"][0]["data"] +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_single_request(model_name: str): + args = [ + "--runner", + "pooling", + # use half precision for speed and memory savings in CI environment + "--dtype", + DTYPE, + "--enforce-eager", + "--trust-remote-code", + "--max-num-seqs", + "32", + "--model-impl", + "terratorch", + "--skip-tokenizer-init", + "--enable-mm-embeds", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as server: + prompt = _terratorch_dummy_inputs(model_name) + + # test single pooling + response = requests.post(server.url_for("pooling"), json=prompt) + response.raise_for_status() - np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32) + output = response.json()["data"][0]["data"] - assert len(np_response) == 524288 + np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32) + assert len(np_response) == 524288 diff --git a/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py new file mode 100644 index 000000000000..224196b9a0b2 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock, patch + +import pytest + +from tests.entrypoints.openai.tool_parsers.utils import ( + run_tool_extraction, + run_tool_extraction_streaming, +) +from vllm.entrypoints.openai.protocol import FunctionCall +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager + +# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 +SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" +SIMPLE_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "San Francisco", "metric": "celsius"}', +) +MORE_TYPES_FUNCTION_OUTPUT = ( + "register_user(name='John Doe', " + "age=37, " + "address={'city': 'San Francisco', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])" +) +MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = ( + "register_user(name='John Doe', " + "age=37, " + "address={'city': 'San Francisco', 'state': 'CA'}, " + "role=null, " + "passed_test=true, " + "aliases=['John', 'Johnny'])" +) +MORE_TYPES_FUNCTION_CALL = FunctionCall( + name="register_user", + arguments='{"name": "John Doe", ' + '"age": 37, ' + '"address": {"city": "San Francisco", "state": "CA"}, ' + '"role": null, ' + '"passed_test": true, ' + '"aliases": ["John", "Johnny"]}', +) +PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()" +PARAMETERLESS_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments="{}", +) +EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})" +EMPTY_DICT_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"additional_data": {}}', +) +EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])" +EMPTY_LIST_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"steps": []}', +) +ESCAPED_STRING_FUNCTION_OUTPUT = ( + r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')" +) 
+ESCAPED_STRING_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', +) + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tool_call(streaming: bool): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer) + model_output = "How can I help you today?" + + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) + + assert content == model_output + assert len(tool_calls) == 0 + + +TEST_CASES = [ + pytest.param( + True, + f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>", + [SIMPLE_FUNCTION_CALL], + id="simple_streaming", + ), + pytest.param( + False, + f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>", + [SIMPLE_FUNCTION_CALL], + id="simple_nonstreaming", + ), + pytest.param( + True, + f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming", + ), + pytest.param( + False, + f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming", + ), + pytest.param( + True, + f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming_json_literals", + ), + pytest.param( + False, + f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming_json_literals", + ), + pytest.param( + True, + f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming", + ), + pytest.param( + False, + f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming", + ), + pytest.param( + True, + f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>", + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming", + ), + pytest.param( + False, + f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>", + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming", + ), + pytest.param( + True, + f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>", + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming", + ), + pytest.param( + False, + f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>", + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming", + ), + pytest.param( + True, + f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>", + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming", + ), + pytest.param( + False, + f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>", + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_nonstreaming", + ), + pytest.param( + True, + f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_streaming", + ), + pytest.param( + False, + f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_nonstreaming", + ), +] + + +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES) +def test_tool_call( + streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall] +): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = 
ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer) + + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) + + assert content is None + assert len(tool_calls) == len(expected_tool_calls) + for actual, expected in zip(tool_calls, expected_tool_calls): + assert actual.type == "function" + assert actual.function == expected + + +def test_streaming_tool_call_with_large_steps(): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer) + model_output_deltas = [ + "<function_calls>get_weather(city='San", + " Francisco', metric='celsius')\n" + f"{PARAMETERLESS_FUNCTION_OUTPUT}\n" + f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>", + ] + + reconstructor = run_tool_extraction_streaming( + tool_parser, model_output_deltas, assert_one_tool_per_delta=False + ) + + assert reconstructor.other_content == "" + assert len(reconstructor.tool_calls) == 3 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL + assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL + + +@pytest.mark.parametrize("streaming", [False]) +def test_regex_timeout_handling(streaming: bool): + """test regex timeout is handled gracefully""" + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer) + + fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 + + # create a mock regex that raises TimeoutError + mock_regex = MagicMock() + mock_regex.match.side_effect = TimeoutError("Regex timeout") + + with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex): + content, tool_calls = run_tool_extraction( + tool_parser, fake_problematic_input, streaming=streaming + ) + + # should treat as regular text when regex times out + assert content == fake_problematic_input + assert len(tool_calls) == 0 + mock_regex.match.assert_called_once() diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index ccd6abbac4c9..d7b4051ea572 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -210,7 +210,7 @@ def test_streaming_tool_call_with_large_steps(): def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( mock_tokenizer ) diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index cfa4d3584e70..7489a406224a 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Union from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, @@ -84,10 +83,10 @@ def append_delta(self, delta: DeltaMessage): def run_tool_extraction( tool_parser: ToolParser, model_output: str, - request: Union[ChatCompletionRequest, None] = None, + request: ChatCompletionRequest | None = None, streaming: bool = False, assert_one_tool_per_delta: bool = True, -) -> tuple[Union[str, None], list[ToolCall]]: +) -> tuple[str | None, 
list[ToolCall]]: if streaming: reconstructor = run_tool_extraction_streaming( tool_parser, @@ -105,7 +104,7 @@ def run_tool_extraction( def run_tool_extraction_nonstreaming( tool_parser: ToolParser, model_output: str, - request: Union[ChatCompletionRequest, None] = None, + request: ChatCompletionRequest | None = None, ) -> ExtractedToolCallInformation: request = request or ChatCompletionRequest(messages=[], model="test-model") return tool_parser.extract_tool_calls(model_output, request) @@ -114,7 +113,7 @@ def run_tool_extraction_nonstreaming( def run_tool_extraction_streaming( tool_parser: ToolParser, model_deltas: Iterable[str], - request: Union[ChatCompletionRequest, None] = None, + request: ChatCompletionRequest | None = None, assert_one_tool_per_delta: bool = True, ) -> StreamingToolReconstructor: request = request or ChatCompletionRequest(messages=[], model="test-model") diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py index ae216c464a5b..96f634ee0a8c 100644 --- a/tests/entrypoints/pooling/llm/test_classify.py +++ b/tests/entrypoints/pooling/llm/test_classify.py @@ -58,10 +58,12 @@ def get_outputs(activation): ) +@pytest.mark.skip_global_cleanup def test_encode_api(llm: LLM): + # chunked prefill does not support all pooling err_msg = "pooling_task must be one of.+" with pytest.raises(ValueError, match=err_msg): - llm.encode(prompts, use_tqdm=False) + llm.encode(prompts, pooling_task="token_classify", use_tqdm=False) def test_score_api(llm: LLM): diff --git a/tests/entrypoints/pooling/llm/test_embedding.py b/tests/entrypoints/pooling/llm/test_embedding.py index aa24a70fd18b..5455b5f91fc0 100644 --- a/tests/entrypoints/pooling/llm/test_embedding.py +++ b/tests/entrypoints/pooling/llm/test_embedding.py @@ -36,6 +36,12 @@ def llm(): @pytest.mark.skip_global_cleanup +def test_encode_api(llm: LLM): + outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False) + multi_vector = outputs[0].outputs.data + assert multi_vector.shape == (11, 384) + + def test_pooling_params(llm: LLM): def get_outputs(normalize): outputs = llm.embed( diff --git a/tests/entrypoints/pooling/llm/test_encode.py b/tests/entrypoints/pooling/llm/test_encode.py index d6aae99944f8..ca85d2758fce 100644 --- a/tests/entrypoints/pooling/llm/test_encode.py +++ b/tests/entrypoints/pooling/llm/test_encode.py @@ -57,24 +57,27 @@ def test_multiple_pooling_params(llm: LLM): ] # Multiple PoolingParams should be matched with each prompt - outputs = llm.encode(PROMPTS, pooling_params=pooling_params) + outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed") assert len(PROMPTS) == len(outputs) # Exception raised, if the size of params does not match the size of prompts with pytest.raises(ValueError): - outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) + outputs = llm.encode( + PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed" + ) # Single PoolingParams should be applied to every prompt single_pooling_params = PoolingParams() - outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) + outputs = llm.encode( + PROMPTS, pooling_params=single_pooling_params, pooling_task="embed" + ) assert len(PROMPTS) == len(outputs) # pooling_params is None, default params should be applied - outputs = llm.encode(PROMPTS, pooling_params=None) + outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed") assert len(PROMPTS) == len(outputs) -@pytest.mark.skip_global_cleanup def 
test_right_side_truncation(llm: LLM): # Embeddings models should truncate the end of the prompt tokenizer = llm.get_tokenizer() diff --git a/tests/entrypoints/pooling/llm/test_reward.py b/tests/entrypoints/pooling/llm/test_reward.py index 8312ff180b36..81058dbad891 100644 --- a/tests/entrypoints/pooling/llm/test_reward.py +++ b/tests/entrypoints/pooling/llm/test_reward.py @@ -36,22 +36,23 @@ def llm(): cleanup_dist_env_and_memory() -@pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(softmax): + def get_outputs(activation): outputs = llm.reward( - prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False + prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False ) return torch.cat([x.outputs.data for x in outputs]) - default = get_outputs(softmax=None) - w_softmax = get_outputs(softmax=True) - wo_softmax = get_outputs(softmax=False) + default = get_outputs(activation=None) + w_activation = get_outputs(activation=True) + wo_activation = get_outputs(activation=False) - assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax." - assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), ( - "wo_softmax should not use softmax." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." ) - assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), ( - "w_softmax should be close to softmax(wo_softmax)." + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." ) diff --git a/tests/entrypoints/pooling/llm/test_score.py b/tests/entrypoints/pooling/llm/test_score.py index 9bf74fce906b..2df973dd7863 100644 --- a/tests/entrypoints/pooling/llm/test_score.py +++ b/tests/entrypoints/pooling/llm/test_score.py @@ -33,7 +33,6 @@ def llm(): cleanup_dist_env_and_memory() -@pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): def get_outputs(activation): text_1 = "What is the capital of France?" 
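The pooling entrypoint changes above tighten the `LLM.encode` contract (an explicit `pooling_task` is now required) and rename the reward-model flag on `PoolingParams` from `softmax` to `activation`. A minimal sketch of the updated call pattern, assuming an embedding model such as `intfloat/multilingual-e5-small` is available locally and that default engine arguments suffice; model name, prompts, and printed shapes are illustrative only:

from vllm import LLM, PoolingParams

llm = LLM(model="intfloat/multilingual-e5-small")
prompts = ["The chef prepared a delicious meal."]

# encode() now requires an explicit pooling task: "embed" yields one vector
# per prompt, "token_embed" yields one vector per prompt token.
embeds = llm.encode(prompts, pooling_task="embed", use_tqdm=False)
token_embeds = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
print(embeds[0].outputs.data.shape)        # roughly (hidden_size,)
print(token_embeds[0].outputs.data.shape)  # roughly (num_tokens, hidden_size)

# For reward-style pooling, PoolingParams now takes `activation` in place of
# the removed `softmax` flag, e.g. PoolingParams(activation=True).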
diff --git a/tests/entrypoints/pooling/openai/test_embedding.py b/tests/entrypoints/pooling/openai/test_embedding.py index 6f6559a961a1..b3f12283fdbd 100644 --- a/tests/entrypoints/pooling/openai/test_embedding.py +++ b/tests/entrypoints/pooling/openai/test_embedding.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import json import numpy as np import openai @@ -14,8 +15,18 @@ from tests.models.language.pooling.embed_utils import run_embedding_correctness_test from tests.models.utils import check_embeddings_close from tests.utils import RemoteOpenAIServer -from vllm.entrypoints.openai.protocol import EmbeddingResponse +from vllm.entrypoints.openai.protocol import ( + EmbeddingResponse, + PoolingResponse, +) from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils.serial_utils import ( + EMBED_DTYPE_TO_TORCH_DTYPE, + ENDIANNESS, + MetadataItem, + binary2tensor, + decode_pooling_output, +) MODEL_NAME = "intfloat/multilingual-e5-small" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 @@ -244,6 +255,116 @@ async def test_batch_base64_embedding( run_embedding_correctness_test(hf_model, input_texts, default_data) +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_base64_embed_dtype_and_endianness( + server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str +): + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] + + responses_float = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="float" + ) + float_data = [d.embedding for d in responses_float.data] + + for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: + for endianness in ENDIANNESS: + responses_base64 = requests.post( + server.url_for("/v1/embeddings"), + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "base64", + "embed_dtype": embed_dtype, + "endianness": endianness, + }, + ) + + base64_data = [] + for data in responses_base64.json()["data"]: + binary = base64.b64decode(data["embedding"]) + tensor = binary2tensor(binary, (-1,), embed_dtype, endianness) + base64_data.append(tensor.to(torch.float32).tolist()) + + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=base64_data, + name_0="float_data", + name_1="base64_data", + tol=1e-2, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_bytes_embed_dtype_and_endianness( + server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str +): + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] + + responses_float = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="float" + ) + float_data = [d.embedding for d in responses_float.data] + + for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()): + for endianness in ENDIANNESS: + responses_bytes = requests.post( + server.url_for("/v1/embeddings"), + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "bytes", + "embed_dtype": embed_dtype, + "endianness": endianness, + }, + ) + + metadata = json.loads(responses_bytes.headers["metadata"]) + body = responses_bytes.content + items = [MetadataItem(**x) for x in metadata["data"]] + + bytes_data = decode_pooling_output(items=items, body=body) + bytes_data = [x.to(torch.float32).tolist() for x in bytes_data] + 
+ check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=bytes_data, + name_0="float_data", + name_1="bytes_data", + tol=1e-2, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"]) +async def test_params_not_supported( + server: RemoteOpenAIServer, model_name: str, param_name: str +): + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] + + responses_base64 = requests.post( + server.url_for("/v1/embeddings"), + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "base64", + param_name: f"bad_{param_name}", + }, + ) + + assert responses_base64.status_code == 400 + assert "literal_error" in responses_base64.json()["error"]["message"] + assert f"bad_{param_name}" in responses_base64.json()["error"]["message"] + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str): @@ -437,3 +558,20 @@ async def get_outputs(normalize): assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( "w_normal should be close to normal(wo_normal)." ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_pooling(server: RemoteOpenAIServer, model_name: str): + input_text = ["The chef prepared a delicious meal."] + + response = requests.post( + server.url_for("pooling"), + json={"model": model_name, "input": input_text, "encoding_format": "float"}, + ) + + poolings = PoolingResponse.model_validate(response.json()) + + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 11 + assert len(poolings.data[0].data[0]) == 384 diff --git a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py index 92df43d7dbdc..ba9fb6426277 100644 --- a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py @@ -4,8 +4,6 @@ Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`. 
""" -from typing import Optional - import openai import pytest @@ -103,14 +101,14 @@ async def make_request_and_correctness_test(dimensions): run_embedding_correctness_test(hf_model, prompts, vllm_outputs, dimensions) if model_info.is_matryoshka: - valid_dimensions: list[Optional[int]] = [None] + valid_dimensions: list[int | None] = [None] if model_info.matryoshka_dimensions is not None: valid_dimensions += model_info.matryoshka_dimensions[:2] for dimensions in valid_dimensions: await make_request_and_correctness_test(dimensions) - invalid_dimensions: list[Optional[int]] = [-1] + invalid_dimensions: list[int | None] = [-1] if model_info.matryoshka_dimensions is not None: assert 5 not in model_info.matryoshka_dimensions invalid_dimensions.append(5) diff --git a/tests/entrypoints/pooling/openai/test_pooling.py b/tests/entrypoints/pooling/openai/test_pooling.py index 3439c556ccc4..4b20c5b0fa84 100644 --- a/tests/entrypoints/pooling/openai/test_pooling.py +++ b/tests/entrypoints/pooling/openai/test_pooling.py @@ -2,15 +2,24 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import json import numpy as np import pytest import requests +import torch from tests.models.utils import check_embeddings_close from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import PoolingResponse from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils.serial_utils import ( + EMBED_DTYPE_TO_TORCH_DTYPE, + ENDIANNESS, + MetadataItem, + binary2tensor, + decode_pooling_output, +) MODEL_NAME = "internlm/internlm2-1_8b-reward" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 @@ -248,6 +257,130 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str) ) +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_base64_embed_dtype_and_endianness( + server: RemoteOpenAIServer, model_name: str +): + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] + + url = server.url_for("pooling") + float_response = requests.post( + url, + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "float", + }, + ) + responses_float = PoolingResponse.model_validate(float_response.json()) + float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] + + for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: + for endianness in ENDIANNESS: + responses_base64 = requests.post( + url, + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "base64", + "embed_dtype": embed_dtype, + "endianness": endianness, + }, + ) + + base64_data = [] + for data in responses_base64.json()["data"]: + binary = base64.b64decode(data["data"]) + tensor = binary2tensor(binary, (-1,), embed_dtype, endianness) + base64_data.append(tensor.to(torch.float32).tolist()) + + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=base64_data, + name_0="float_data", + name_1="base64_data", + tol=1e-2, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_bytes_embed_dtype_and_endianness( + server: RemoteOpenAIServer, model_name: str +): + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] + + url = server.url_for("pooling") + float_response = requests.post( + url, + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "float", 
+ }, + ) + responses_float = PoolingResponse.model_validate(float_response.json()) + float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] + + for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()): + for endianness in ENDIANNESS: + responses_bytes = requests.post( + url, + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "bytes", + "embed_dtype": embed_dtype, + "endianness": endianness, + }, + ) + + metadata = json.loads(responses_bytes.headers["metadata"]) + body = responses_bytes.content + items = [MetadataItem(**x) for x in metadata["data"]] + + bytes_data = decode_pooling_output(items=items, body=body) + bytes_data = [x.to(torch.float32).view(-1).tolist() for x in bytes_data] + + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=bytes_data, + name_0="float_data", + name_1="bytes_data", + tol=1e-2, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"]) +async def test_params_not_supported( + server: RemoteOpenAIServer, model_name: str, param_name: str +): + input_texts = [ + "The best thing about vLLM is that it supports many different models", + ] + + responses_base64 = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": input_texts, + "encoding_format": "base64", + param_name: f"bad_{param_name}", + }, + ) + + assert responses_base64.status_code == 400 + assert "literal_error" in responses_base64.json()["error"]["message"] + assert f"bad_{param_name}" in responses_base64.json()["error"]["message"] + + @pytest.mark.asyncio async def test_invocations(server: RemoteOpenAIServer): input_texts = [ diff --git a/tests/entrypoints/pooling/openai/test_rerank.py b/tests/entrypoints/pooling/openai/test_rerank.py index 9980fcff16c1..e43148d25fee 100644 --- a/tests/entrypoints/pooling/openai/test_rerank.py +++ b/tests/entrypoints/pooling/openai/test_rerank.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from tests.utils import RemoteOpenAIServer -from vllm.entrypoints.openai.protocol import RerankResponse +from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse MODEL_NAME = "BAAI/bge-reranker-base" DTYPE = "bfloat16" @@ -159,3 +159,20 @@ async def get_outputs(activation): assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( "w_activation should be close to activation(wo_activation)." 
) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_pooling(server: RemoteOpenAIServer, model_name: str): + input_text = ["The chef prepared a delicious meal."] + + response = requests.post( + server.url_for("pooling"), + json={"model": model_name, "input": input_text, "encoding_format": "float"}, + ) + + poolings = PoolingResponse.model_validate(response.json()) + + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 11 + assert len(poolings.data[0].data[0]) == 1 diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index e548f52e1e94..3fadbf2ef0dd 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -5,7 +5,6 @@ import socket import threading import time -from typing import Optional from unittest.mock import patch import pytest @@ -105,7 +104,7 @@ def test_wait_for_completion_or_failure(api_server_args): assert len(manager.processes) == 3 # Create a result capture for the thread - result: dict[str, Optional[Exception]] = {"exception": None} + result: dict[str, Exception | None] = {"exception": None} def run_with_exception_capture(): try: @@ -218,7 +217,7 @@ def close(self): assert len(manager.processes) == 3 # Create a result capture for the thread - result: dict[str, Optional[Exception]] = {"exception": None} + result: dict[str, Exception | None] = {"exception": None} def run_with_exception_capture(): try: diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 6e92419c4f67..378c2624f7d9 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -3,11 +3,10 @@ import warnings from collections.abc import Mapping -from typing import Literal, Optional +from typing import Literal import pytest -from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens -from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset @@ -74,6 +73,19 @@ def phi3v_model_config_mm_interleaved(): ) +@pytest.fixture(scope="function") +def phi3v_model_config_image_embeds(): + return ModelConfig( + PHI3V_MODEL_ID, + runner="generate", + trust_remote_code=True, + limit_mm_per_prompt={ + "image": 2, + }, + enable_mm_embeds=True, + ) + + @pytest.fixture(scope="module") def phi3v_tokenizer(): return get_tokenizer(PHI3V_MODEL_ID) @@ -153,9 +165,9 @@ def audio_url(): def _assert_mm_data_is_image_input( - mm_data: Optional[MultiModalDataDict], + mm_data: MultiModalDataDict | None, image_count: int, - skipped_image_indices: Optional[list] = None, + skipped_image_indices: list | None = None, ) -> None: assert mm_data is not None assert set(mm_data.keys()) == {"image"} @@ -170,9 +182,9 @@ def _assert_mm_data_is_image_input( def _assert_mm_uuids( - mm_uuids: Optional[MultiModalUUIDDict], + mm_uuids: MultiModalUUIDDict | None, media_count: int, - expected_uuids: list[Optional[str]], + expected_uuids: list[str | None], modality: str = "image", ) -> None: if len(expected_uuids) > 0: @@ -194,9 +206,9 @@ def _assert_mm_uuids( def _assert_mm_data_inputs( - mm_data: Optional[MultiModalDataDict], + mm_data: MultiModalDataDict | None, data_count: MultiModalDataCounts, - skipped_media_indices: Optional[dict[str, list]] = None, # modality -> list[int] + 
skipped_media_indices: dict[str, list] | None = None, # modality -> list[int] ) -> None: assert mm_data is not None assert set(data_count.keys()) == (set(mm_data.keys())) @@ -800,7 +812,7 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( def test_parse_chat_messages_empty_image_embeds_with_uuid( - phi3v_model_config, + phi3v_model_config_image_embeds, phi3v_tokenizer, ): uuid = "abcd" @@ -814,7 +826,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( ], } ], - phi3v_model_config, + phi3v_model_config_image_embeds, phi3v_tokenizer, content_format="string", ) @@ -833,7 +845,7 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( @pytest.mark.asyncio async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( - phi3v_model_config, + phi3v_model_config_image_embeds, phi3v_tokenizer, ): uuid = "abcd" @@ -847,7 +859,7 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( ], } ], - phi3v_model_config, + phi3v_model_config_image_embeds, phi3v_tokenizer, content_format="string", ) @@ -1730,7 +1742,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, dtype=model_info.dtype, ) @@ -1811,6 +1825,7 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwa "unsed_kwargs_2": "abc", # should not appear "chat_template": "{% Hello world! %}", + "tokenize": True, # used by tokenizer "continue_final_message": True, "tools": tools, @@ -1829,7 +1844,9 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwa revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, dtype=model_info.dtype, ) @@ -1847,10 +1864,21 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwa tools=tools, model_config=model_config, ) + with pytest.raises( + ValueError, match="Found unexpected chat template kwargs from request" + ): + # should raise error if `chat_template_kwargs` contains + # `chat_template` or `tokenize` + resolve_chat_template_kwargs( + tokenizer, + chat_template=chat_template, + chat_template_kwargs=chat_template_kwargs, + ) resolved_chat_template_kwargs = resolve_chat_template_kwargs( tokenizer, chat_template=chat_template, chat_template_kwargs=chat_template_kwargs, + raise_on_unexpected=False, ) assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs @@ -1879,7 +1907,9 @@ def test_resolve_content_format_hf_defined(model, expected_format): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, dtype=model_info.dtype, ) @@ -1937,7 +1967,9 @@ def 
test_resolve_content_format_fallbacks(model, expected_format): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, dtype=model_info.dtype, ) @@ -2119,34 +2151,9 @@ def test_apply_mistral_chat_template_thinking_chunk(): }, {"role": "user", "content": "Thanks, what is 3+3?"}, ] - - # TODO(Julien): upon model release change to a tokenizer already configured. - # ================================================================= mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Devstral-Small-2507" - ) - assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) - # Add think special tokens to the tokenizer - mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( - rank=35, is_control=True, token_str=SpecialTokens.begin_think.value + "mistralai/Magistral-Small-2509" ) - mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( - rank=36, is_control=True, token_str=SpecialTokens.end_think.value - ) - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { - k: v - for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() - if v not in {35, 36} - } - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.begin_think.value - ] = 35 - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.end_think.value - ] = 36 - mistral_tokenizer.instruct.BEGIN_THINK = 35 - mistral_tokenizer.instruct.END_THINK = 36 - # ================================================================= tokens_ids = apply_mistral_chat_template( mistral_tokenizer, messages, chat_template=None, tools=None diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py index 6ad18fa08bc4..31ea856224f9 100644 --- a/tests/entrypoints/test_context.py +++ b/tests/entrypoints/test_context.py @@ -6,16 +6,14 @@ import pytest from openai_harmony import Author, Message, Role, StreamState, TextContent -from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext +from vllm.entrypoints.context import ( + HarmonyContext, + StreamingHarmonyContext, + TurnMetrics, +) from vllm.outputs import CompletionOutput, RequestOutput -# Helper function for Python < 3.10 compatibility -async def async_next(async_iterator): - """Compatibility function equivalent to Python 3.10's anext().""" - return await async_iterator.__anext__() - - def create_mock_request_output( prompt_token_ids=None, output_token_ids=None, @@ -107,8 +105,12 @@ def test_single_turn_token_counting(): # Verify internal state tracking assert not context.is_first_turn - assert context.previous_turn.input_tokens == 5 - assert context.previous_turn.output_tokens == 3 + assert len(context.all_turn_metrics) == 1 + previous_turn = context.all_turn_metrics[0] + assert previous_turn.input_tokens == 5 + assert previous_turn.output_tokens == 3 + assert previous_turn.cached_input_tokens == 2 + assert previous_turn.tool_output_tokens == 0 @pytest.mark.asyncio @@ -129,7 +131,7 @@ async def test_multi_turn_token_counting(): ) # First turn - initial prompt and response - mock_output1 = await async_next(mock_generator) + mock_output1 = await anext(mock_generator) context.append_output(mock_output1) # At this point, we should have 5 prompt tokens 
and 3 output tokens @@ -138,7 +140,7 @@ async def test_multi_turn_token_counting(): assert context.num_tool_output_tokens == 0 # Second turn - after tool output - mock_output2 = await async_next(mock_generator) + mock_output2 = await anext(mock_generator) context.append_output(mock_output2) # Current prompt tokens (15) - last_turn_input_tokens (5) - # last_turn_output_tokens (3) = 7 @@ -150,7 +152,7 @@ async def test_multi_turn_token_counting(): assert context.num_cached_tokens == 5 # Third turn - final response - mock_output3 = await async_next(mock_generator) + mock_output3 = await anext(mock_generator) context.append_output(mock_output3) # Additional tool output tokens from third turn: # Current prompt (20) - last_turn_input_tokens (15) - @@ -162,6 +164,15 @@ async def test_multi_turn_token_counting(): assert context.num_tool_output_tokens == expected_tool_output assert context.num_cached_tokens == 5 + 15 + # Validate all turn metrics + assert len(context.all_turn_metrics) == 3 + for i, turn in enumerate(context.all_turn_metrics): + assert turn.input_tokens == prompt_token_counts[i] + assert turn.output_tokens == output_token_counts[i] + assert turn.cached_input_tokens == cached_token_counts[i] + assert context.all_turn_metrics[1].tool_output_tokens == 7 + assert context.all_turn_metrics[2].tool_output_tokens == 1 + def test_empty_output_tokens(): """Test behavior when RequestOutput has empty output tokens.""" @@ -320,6 +331,10 @@ async def test_streaming_multi_turn_token_counting(mock_parser): # Create a streaming context context = StreamingHarmonyContext(messages=[], available_tools=["browser"]) + num_prompt_tokens = [3, 8, 13] + num_output_tokens = [3, 3, 2] + num_cached_tokens = [0, 3, 8] + # Simulate three turns of conversation: # Turn 1: stream tokens one by one, then finish the message # Turn 2: new prompt, stream more tokens with a reasoning segment @@ -331,7 +346,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser): create_mock_request_output( prompt_token_ids=[1, 2, 3], # 3 prompt tokens output_token_ids=[101], # Single token - num_cached_tokens=0, + num_cached_tokens=num_cached_tokens[0], finished=False, # Not end of message yet ) ) @@ -376,7 +391,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser): 5, ], # 8 tokens (includes previous) output_token_ids=[201], - num_cached_tokens=3, # Some tokens cached + num_cached_tokens=num_cached_tokens[1], # Some tokens cached finished=False, ) ) @@ -428,7 +443,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser): 7, ], # 13 tokens output_token_ids=[301], - num_cached_tokens=8, # More cached tokens + num_cached_tokens=num_cached_tokens[2], # More cached tokens finished=False, ) ) @@ -441,10 +456,12 @@ async def test_streaming_multi_turn_token_counting(mock_parser): ) # Final token counts check - assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts - assert context.num_output_tokens == 3 + 3 + 2 # All outputs + assert context.num_prompt_tokens == sum(num_prompt_tokens) # All prompts + assert context.num_output_tokens == sum(num_output_tokens) # All outputs assert context.num_reasoning_tokens == 3 # Unchanged from second turn - assert context.num_cached_tokens == 3 + 8 # Accumulated cached tokens + assert context.num_cached_tokens == sum( + num_cached_tokens + ) # Accumulated cached tokens # Additional tool tokens from third turn # Formula: this turn prompt - last turn prompt - last turn output @@ -453,6 +470,15 @@ async def test_streaming_multi_turn_token_counting(mock_parser): 
context.num_tool_output_tokens == expected_tool_tokens + additional_tool_tokens ) + # Validate all turn metrics + assert len(context.all_turn_metrics) == 3 + for i, turn in enumerate(context.all_turn_metrics): + assert turn.input_tokens == num_prompt_tokens[i] + assert turn.output_tokens == num_output_tokens[i] + assert turn.cached_input_tokens == num_cached_tokens[i] + assert context.all_turn_metrics[1].tool_output_tokens == 2 + assert context.all_turn_metrics[2].tool_output_tokens == 2 + @pytest.mark.asyncio async def test_streaming_message_synchronization(mock_parser): @@ -528,3 +554,46 @@ async def test_streaming_message_synchronization(mock_parser): assert len(context._messages) == 3 assert context.num_init_messages == 1 assert context._messages[2].content[0].text == "Response 4" + + +def test_turn_metrics_copy_and_reset(): + """Test TurnMetrics copy and reset methods work correctly.""" + # Create a TurnMetrics with specific values + original_metrics = TurnMetrics( + input_tokens=10, + output_tokens=20, + cached_input_tokens=5, + tool_output_tokens=3, + ) + + # Test copy functionality + copied_metrics = original_metrics.copy() + + # Verify copy has same values + assert copied_metrics.input_tokens == 10 + assert copied_metrics.output_tokens == 20 + assert copied_metrics.cached_input_tokens == 5 + assert copied_metrics.tool_output_tokens == 3 + + # Verify they are separate objects + assert copied_metrics is not original_metrics + + # Modify copy to ensure independence + copied_metrics.input_tokens = 999 + assert original_metrics.input_tokens == 10 # Original unchanged + assert copied_metrics.input_tokens == 999 + + # Test reset functionality + original_metrics.reset() + + # Verify all fields are reset to zero + assert original_metrics.input_tokens == 0 + assert original_metrics.output_tokens == 0 + assert original_metrics.cached_input_tokens == 0 + assert original_metrics.tool_output_tokens == 0 + + # Verify copied metrics are unaffected by reset + assert copied_metrics.input_tokens == 999 + assert copied_metrics.output_tokens == 20 + assert copied_metrics.cached_input_tokens == 5 + assert copied_metrics.tool_output_tokens == 3 diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py index f93978c3e6e7..b0ef3dd045bd 100644 --- a/tests/entrypoints/test_renderer.py +++ b/tests/entrypoints/test_renderer.py @@ -3,7 +3,6 @@ import io from dataclasses import dataclass -from typing import Optional from unittest.mock import AsyncMock, MagicMock import pybase64 @@ -17,7 +16,8 @@ @dataclass class MockModelConfig: max_model_len: int = 100 - encoder_config: Optional[dict] = None + encoder_config: dict | None = None + enable_prompt_embeds: bool = True class MockTokenizerResult: diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py index 9edec7a78ca2..c7799607912b 100644 --- a/tests/evals/gsm8k/gsm8k_eval.py +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -12,7 +12,6 @@ import os import time from collections.abc import Generator -from typing import Optional, Union import aiohttp import numpy as np @@ -23,7 +22,7 @@ INVALID = -9999999 -def download_and_cache_file(url: str, filename: Optional[str] = None) -> str: +def download_and_cache_file(url: str, filename: str | None = None) -> str: """Download and cache a file from a URL.""" if filename is None: filename = os.path.join("/tmp", url.split("/")[-1]) @@ -81,9 +80,9 @@ async def call_vllm_api( prompt: str, temperature: float, max_tokens: int, - stop: Optional[list[str]] = None, - url: Optional[str] = None, 
- seed: Optional[int] = None, + stop: list[str] | None = None, + url: str | None = None, + seed: int | None = None, ) -> str: """Call vLLM's OpenAI-compatible completions endpoint.""" data = { @@ -112,8 +111,8 @@ def evaluate_gsm8k( host: str = "http://127.0.0.1", port: int = 8000, temperature: float = 0.0, - seed: Optional[int] = 42, -) -> dict[str, Union[float, int]]: + seed: int | None = 42, +) -> dict[str, float | int]: """ Evaluate GSM8K accuracy using vLLM serve endpoint. diff --git a/tests/kernels/attention/conftest.py b/tests/kernels/attention/conftest.py index b080a71bd54e..e520267320c0 100644 --- a/tests/kernels/attention/conftest.py +++ b/tests/kernels/attention/conftest.py @@ -3,7 +3,10 @@ import pytest -from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash +from vllm.utils.torch_utils import ( + create_kv_caches_with_random, + create_kv_caches_with_random_flash, +) @pytest.fixture() diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py index 88b21a9b84d6..1dec46e33f22 100644 --- a/tests/kernels/attention/test_aiter_flash_attn.py +++ b/tests/kernels/attention/test_aiter_flash_attn.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -27,8 +26,8 @@ def ref_paged_attn( kv_lens: list[int], block_tables: torch.Tensor, scale: float, - sliding_window: Optional[int] = None, - soft_cap: Optional[float] = None, + sliding_window: int | None = None, + soft_cap: float | None = None, ) -> torch.Tensor: num_seqs = len(query_lens) block_tables = block_tables.cpu().numpy() @@ -94,12 +93,12 @@ def test_varlen_with_paged_kv( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], head_size: int, - sliding_window: Optional[int], + sliding_window: int | None, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], + soft_cap: float | None, num_blocks: int, - q_dtype: Optional[torch.dtype], + q_dtype: torch.dtype | None, ) -> None: torch.set_default_device("cuda") current_platform.seed_everything(0) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 16e544eb3cf9..9662e73321eb 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -from typing import Optional import pytest import torch @@ -12,7 +11,7 @@ from vllm import _custom_ops as ops from vllm.attention.layer import Attention, MultiHeadAttention from vllm.platforms import current_platform -from vllm.utils import get_max_shared_memory_bytes +from vllm.utils.mem_utils import get_max_shared_memory_bytes if not current_platform.is_rocm(): from xformers import ops as xops @@ -50,7 +49,7 @@ def ref_masked_attention( key: torch.Tensor, value: torch.Tensor, scale: float, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() if attn_mask is not None: @@ -69,7 +68,7 @@ def ref_single_query_cached_kv_attention( block_tables: torch.Tensor, seq_lens: torch.Tensor, scale: float, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, ) -> None: num_query_heads = query.shape[1] num_kv_heads = value_cache.shape[1] @@ -415,7 +414,7 @@ def ref_multi_query_kv_attention( key: torch.Tensor, 
value: torch.Tensor, scale: float, - alibi_bias: Optional[list[torch.Tensor]], + alibi_bias: list[torch.Tensor] | None, dtype: torch.dtype, ) -> torch.Tensor: num_seqs = len(cu_seq_lens) - 1 diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index fa95c3b2d39e..48a42ce6ffab 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -34,7 +34,7 @@ def clear_cache(): DEVICE_REGULAR_ATTN_BACKENDS = { "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"], - "hip": ["ROCM_FLASH"], + "hip": ["ROCM_ATTN"], "cpu": ["TORCH_SDPA"], } @@ -84,12 +84,12 @@ def test_env( m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") if device == "cpu": - with patch("vllm.attention.selector.current_platform", CpuPlatform()): + with patch("vllm.platforms.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float16, None, block_size) assert backend.get_name() == "TORCH_SDPA" elif device == "hip": - with patch("vllm.attention.selector.current_platform", RocmPlatform()): + with patch("vllm.platforms.current_platform", RocmPlatform()): if use_mla: # ROCm MLA backend logic: # - TRITON_MLA: supported when block_size != 1 @@ -122,11 +122,11 @@ def test_env( backend = get_attn_backend( 16, torch.float16, None, block_size, use_mla=use_mla ) - expected = "TRITON_ATTN" + expected = "ROCM_ATTN" assert backend.get_name() == expected elif device == "cuda": - with patch("vllm.attention.selector.current_platform", CudaPlatform()): + with patch("vllm.platforms.current_platform", CudaPlatform()): if use_mla: # CUDA MLA backend logic: # - CUTLASS_MLA: only supported with block_size == 128 @@ -214,12 +214,12 @@ def test_env( def test_fp32_fallback(device: str): """Test attention backend selection with fp32.""" if device == "cpu": - with patch("vllm.attention.selector.current_platform", CpuPlatform()): + with patch("vllm.platforms.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float32, None, 16) assert backend.get_name() == "TORCH_SDPA" elif device == "cuda": - with patch("vllm.attention.selector.current_platform", CudaPlatform()): + with patch("vllm.platforms.current_platform", CudaPlatform()): backend = get_attn_backend(16, torch.float32, None, 16) assert backend.get_name() == "FLEX_ATTENTION" @@ -277,7 +277,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch): """Test that invalid attention backend names raise ValueError.""" with ( monkeypatch.context() as m, - patch("vllm.attention.selector.current_platform", CudaPlatform()), + patch("vllm.platforms.current_platform", CudaPlatform()), ): m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) diff --git a/tests/kernels/attention/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py index 58e8bd592ba4..4295f852f95b 100755 --- a/tests/kernels/attention/test_cascade_flash_attn.py +++ b/tests/kernels/attention/test_cascade_flash_attn.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -85,7 +84,7 @@ def test_cascade( head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], + soft_cap: float | None, num_blocks: int, fa_version: int, ) -> None: diff --git a/tests/kernels/attention/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py index dad1510ce532..a60f4e385a89 100644 --- 
a/tests/kernels/attention/test_cutlass_mla_decode.py +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math import random -from typing import Optional import pytest import torch @@ -17,7 +16,7 @@ def cal_diff( y: torch.Tensor, name: str, use_fp8: bool = False, - diff_threshold: Optional[float] = None, + diff_threshold: float | None = None, ) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index 4873afa649c9..74a5d8117962 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -6,7 +6,7 @@ import torch from vllm.platforms import current_platform -from vllm.utils import cdiv, has_deep_gemm +from vllm.utils import cdiv from vllm.utils.deep_gemm import ( _ceil_to_ue8m0, calc_diff, @@ -15,6 +15,7 @@ get_num_sms, get_paged_mqa_logits_metadata, ) +from vllm.utils.import_utils import has_deep_gemm def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: @@ -82,8 +83,7 @@ def _ref_fp8_mqa_logits( torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] ) mask = mask_lo & mask_hi - - score = torch.einsum("mhd,and->hmn", q, k) + score = torch.einsum("mhd,nd->hmn", q, k) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) logits = logits.masked_fill(~mask, float("-inf")) diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index d39f0a593ed4..18995545552e 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -34,8 +33,8 @@ def ref_paged_attn( kv_lens: list[int], block_tables: torch.Tensor, scale: float, - sliding_window: Optional[int] = None, - soft_cap: Optional[float] = None, + sliding_window: int | None = None, + soft_cap: float | None = None, ) -> torch.Tensor: num_seqs = len(query_lens) block_tables = block_tables.cpu().numpy() @@ -103,11 +102,11 @@ def test_flash_attn_with_paged_kv( head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], + soft_cap: float | None, num_blocks: int, - sliding_window: Optional[int], + sliding_window: int | None, fa_version: int, - q_dtype: Optional[torch.dtype], + q_dtype: torch.dtype | None, ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): @@ -221,13 +220,13 @@ def test_varlen_with_paged_kv( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], head_size: int, - sliding_window: Optional[int], + sliding_window: int | None, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], + soft_cap: float | None, num_blocks: int, fa_version: int, - q_dtype: Optional[torch.dtype], + q_dtype: torch.dtype | None, ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index 52cd10fdc5be..82ec2ef14e56 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project -from typing import Optional import flashinfer import pytest @@ -26,8 +25,8 @@ def ref_paged_attn( kv_lens: list[int], block_tables: torch.Tensor, scale: float, - sliding_window: Optional[int] = None, - soft_cap: Optional[float] = None, + sliding_window: int | None = None, + soft_cap: float | None = None, ) -> torch.Tensor: num_seqs = len(query_lens) block_tables = block_tables.cpu().numpy() @@ -90,8 +89,8 @@ def test_flashinfer_decode_with_paged_kv( head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], - sliding_window: Optional[int], + soft_cap: float | None, + sliding_window: int | None, ) -> None: torch.set_default_device("cuda") current_platform.seed_everything(0) @@ -185,8 +184,8 @@ def test_flashinfer_prefill_with_paged_kv( head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], - sliding_window: Optional[int], + soft_cap: float | None, + sliding_window: int | None, ) -> None: torch.set_default_device("cuda") current_platform.seed_everything(0) @@ -288,7 +287,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], + soft_cap: float | None, ) -> None: pytest.skip("TODO: fix the accuracy issue") torch.set_default_device("cuda") @@ -398,7 +397,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], + soft_cap: float | None, ) -> None: # test doesn't work for num_heads = (16,16) torch.set_default_device("cuda") diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 62d94f0bb751..00f06da5a47b 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import flashinfer import pytest import torch from tests.kernels.quantization.nvfp4_utils import ( - FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype, + get_nvfp4_global_scale, ) from vllm.platforms import current_platform from vllm.utils import round_up @@ -50,6 +48,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn): BLOCK_SIZE = [16] WINDOW_LEFT = [-1, 127] SOFT_CAP = [None, 50.0] +HAS_SINKS = [True, False] NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. 
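The FlashInfer TRT-LLM attention tests above now parametrize over HAS_SINKS and, when sinks are enabled, build the fa2 baseline with BatchAttentionWithAttentionSinkWrapper. Conceptually, a sink is one extra per-head logit that joins the softmax normalizer but contributes no value row, so a large sink drains probability mass away from the real keys. A small pure-PyTorch sketch of that behaviour for a single head in decode, assuming this sink formulation; the helper name and shapes below are illustrative, not FlashInfer's API:

import torch

def ref_decode_with_sink(q, k, v, sink, scale):
    # q: [head_dim], k/v: [kv_len, head_dim], sink: scalar logit for this head.
    logits = (k @ q) * scale                    # [kv_len]
    logits = torch.cat([logits, sink.view(1)])  # append the sink logit
    probs = torch.softmax(logits, dim=-1)[:-1]  # sink gets weight but no value row
    return probs @ v                            # [head_dim]

d, n = 8, 5
q, k, v = torch.randn(d), torch.randn(n, d), torch.randn(n, d)
out_neutral = ref_decode_with_sink(q, k, v, torch.tensor(-100.0), d**-0.5)
out_sinky = ref_decode_with_sink(q, k, v, torch.tensor(100.0), d**-0.5)
assert out_sinky.norm() < out_neutral.norm()  # a dominant sink suppresses the output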
@@ -64,12 +63,11 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @pytest.mark.parametrize("block_size", BLOCK_SIZE) @pytest.mark.parametrize("window_left", WINDOW_LEFT) @pytest.mark.parametrize("soft_cap", SOFT_CAP) +@pytest.mark.parametrize("has_sinks", HAS_SINKS) @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( dtype: torch.dtype, - quant_dtypes: tuple[ - Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] - ], + quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None], batch_size: int, max_seq_lens: tuple[int, int], num_heads: tuple[int, int], @@ -77,10 +75,11 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_layout: str, block_size: int, window_left: int, - soft_cap: Optional[float], + soft_cap: float | None, + has_sinks: bool, ) -> None: torch.set_default_device("cuda") - current_platform.seed_everything(0) + current_platform.seed_everything(42) q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes q_quant_dtype = q_quant_dtype or dtype @@ -102,7 +101,16 @@ def test_flashinfer_trtllm_decode_with_baseline( else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + # max_q_len = 1 + q_lens = torch.ones((batch_size,), dtype=torch.int32) + q_indptr = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ] + ) + + query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) if q_quant_dtype == FP8_DTYPE: query, q_scale = to_float8(query) ref_query = query.to(dtype) * q_scale @@ -113,7 +121,7 @@ def test_flashinfer_trtllm_decode_with_baseline( kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_kv_len - seq_lens = kv_lens + seq_lens = kv_lens + q_lens max_seq_len = torch.max(seq_lens).item() kv_cache = torch.randn(kv_cache_shape, dtype=dtype) @@ -149,35 +157,43 @@ def test_flashinfer_trtllm_decode_with_baseline( workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) # Baseline Decode - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, use_tensor_cores=True - ) + if has_sinks: + sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5 + wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper( + float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2" + ) + else: + sinks = None + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2" + ) + wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_size, - block_size, - "NONE", + qo_indptr=q_indptr, + paged_kv_indptr=kv_indptr, + paged_kv_indices=kv_indices, + paged_kv_last_page_len=kv_last_page_lens, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_size, + page_size=block_size, + causal=True, sm_scale=sm_scale, - q_data_type=dtype, - kv_data_type=dtype, window_left=window_left, logits_soft_cap=soft_cap, + q_data_type=dtype, + kv_data_type=dtype, ) - output = torch.empty(ref_query.shape, dtype=dtype) - wrapper.run(ref_query, ref_kv_cache, out=output) + wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output) + o_scale = 1.0 - o_sf_scale = None + o_sf_scale_float = None if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) elif o_quant_dtype == FP4_DTYPE: - o_sf_scale = ( - (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1) - 
).to(torch.float32) + o_sf_scale = get_nvfp4_global_scale(output) + o_sf_scale_float = o_sf_scale.item() # TRTLLM Decode if o_quant_dtype == FP4_DTYPE: @@ -204,7 +220,8 @@ def test_flashinfer_trtllm_decode_with_baseline( bmm1_scale=q_scale * k_scale * sm_scale, bmm2_scale=v_scale / o_scale, window_left=window_left, - o_sf_scale=o_sf_scale, + sinks=sinks, + o_sf_scale=o_sf_scale_float, out=output_trtllm, ) if o_quant_dtype == FP8_DTYPE: @@ -219,11 +236,13 @@ def test_flashinfer_trtllm_decode_with_baseline( output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: - rtol, atol = 3e-1, 1e0 + rtol, atol = 7e-2, 9e-2 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: - rtol, atol = 5e-2, 7e-2 - else: + rtol, atol = 2e-2, 4e-2 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: rtol, atol = 1e-2, 2e-2 + else: + rtol, atol = 1e-2, 1e-2 ( torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), @@ -241,12 +260,11 @@ def test_flashinfer_trtllm_decode_with_baseline( @pytest.mark.parametrize("block_size", BLOCK_SIZE) @pytest.mark.parametrize("window_left", WINDOW_LEFT) @pytest.mark.parametrize("soft_cap", [None]) +@pytest.mark.parametrize("has_sinks", HAS_SINKS) @torch.inference_mode def test_flashinfer_trtllm_prefill_with_baseline( dtype: torch.dtype, - quant_dtypes: tuple[ - Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] - ], + quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None], batch_size: int, max_seq_lens: tuple[int, int], num_heads: tuple[int, int], @@ -254,10 +272,11 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_layout: str, block_size: int, window_left: int, - soft_cap: Optional[float], + soft_cap: float | None, + has_sinks: bool, ) -> None: torch.set_default_device("cuda") - current_platform.seed_everything(0) + current_platform.seed_everything(42) q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes q_quant_dtype = q_quant_dtype or dtype @@ -299,7 +318,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( q_scale = 1.0 ref_query = query - kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) + kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_kv_len seq_lens = kv_lens + q_lens @@ -338,36 +357,43 @@ def test_flashinfer_trtllm_prefill_with_baseline( workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8) # Baseline Prefill - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout - ) + if has_sinks: + sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5 + wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper( + float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2" + ) + else: + sinks = None + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2" + ) + wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_size, - block_size, + qo_indptr=q_indptr, + paged_kv_indptr=kv_indptr, + paged_kv_indices=kv_indices, + paged_kv_last_page_len=kv_last_page_lens, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_size, + page_size=block_size, causal=True, sm_scale=sm_scale, - q_data_type=dtype, - kv_data_type=dtype, window_left=window_left, logits_soft_cap=soft_cap, + q_data_type=dtype, + kv_data_type=dtype, ) - output = 
torch.empty(ref_query.shape, dtype=dtype) - wrapper.run(ref_query, ref_kv_cache, out=output) + wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output) + o_scale = 1.0 - o_sf_scale = None + o_sf_scale_float = None if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) elif o_quant_dtype == FP4_DTYPE: - o_sf_scale = ( - (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1) - ).to(torch.float32) + o_sf_scale = get_nvfp4_global_scale(output) + o_sf_scale_float = o_sf_scale.item() # TRTLLM Prefill if o_quant_dtype == FP4_DTYPE: @@ -398,7 +424,8 @@ def test_flashinfer_trtllm_prefill_with_baseline( cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, window_left=window_left, - o_sf_scale=o_sf_scale, + sinks=sinks, + o_sf_scale=o_sf_scale_float, out=output_trtllm, ) if o_quant_dtype == FP8_DTYPE: @@ -413,11 +440,11 @@ def test_flashinfer_trtllm_prefill_with_baseline( output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: - rtol, atol = 4e-1, 1e0 + rtol, atol = 1e-1, 2e-1 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: - rtol, atol = 5e-2, 7e-2 - elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: rtol, atol = 4e-2, 6e-2 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype: + rtol, atol = 2e-2, 3e-2 else: rtol, atol = 1e-2, 1e-2 diff --git a/tests/kernels/attention/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py index eb9204dfaf15..9b084f2f660b 100644 --- a/tests/kernels/attention/test_merge_attn_states.py +++ b/tests/kernels/attention/test_merge_attn_states.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -20,7 +19,7 @@ def merge_attn_states_torch( prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] - output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] + output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS] ): p_lse = prefix_lse s_lse = suffix_lse diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 5ff2624cd7a4..65972d02f2f6 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -15,7 +15,7 @@ from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 64] diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index a59230528770..9b7fb664956c 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -18,7 +18,7 @@ def clear_cache(): @pytest.mark.skip(reason="Skipped for now. 
Should be revisited.") def test_selector(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") + m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN") # Set the current platform to ROCm using monkeypatch monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index fba82cfdadbd..bf4d2179af5f 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -32,8 +31,8 @@ def ref_paged_attn( kv_lens: list[int], block_tables: torch.Tensor, scale: float, - sliding_window: Optional[int] = None, - soft_cap: Optional[float] = None, + sliding_window: int | None = None, + soft_cap: float | None = None, ) -> torch.Tensor: num_seqs = len(query_lens) block_tables = block_tables.cpu().numpy() @@ -98,12 +97,12 @@ def test_triton_unified_attn( seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], head_size: int, - sliding_window: Optional[int], + sliding_window: int | None, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float], + soft_cap: float | None, num_blocks: int, - q_dtype: Optional[torch.dtype], + q_dtype: torch.dtype | None, ) -> None: torch.set_default_device("cuda") diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 52133ec53d1d..63b5a37d3c77 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ b/tests/kernels/core/test_fused_quant_layernorm.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import pytest import torch @@ -16,7 +15,6 @@ # Avoid combinatorial explosion with full Cartesian product NUM_TOKENS_HIDDEN_SIZES = [ *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]], - *[(83, i) for i in [1, 1033, 2048, 5120]], *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]], *[(4096, i) for i in [1, 64, 5137]], ] @@ -31,13 +29,13 @@ ## Helpers -def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: +def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor: return torch.as_tensor(x, dtype=torch.float32, device="cuda") def ref_rms_norm( - rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor] -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor | None +) -> tuple[torch.Tensor, torch.Tensor | None]: if residual is not None: residual = residual.clone() out, residual = rms_norm_layer.forward_native(x, residual) @@ -51,9 +49,9 @@ def ref_dynamic_per_token_quant( rms_norm_layer: RMSNorm, x: torch.Tensor, quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor], -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + residual: torch.Tensor | None, + scale_ub: torch.Tensor | None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: if scale_ub is not None: assert quant_dtype == torch.float8_e4m3fn @@ -76,9 +74,9 @@ def ref_impl( rms_norm_layer: RMSNorm, x: torch.Tensor, quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor], -) -> 
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + residual: torch.Tensor | None, + scale_ub: torch.Tensor | None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: return ref_dynamic_per_token_quant( rms_norm_layer, x, quant_dtype, residual, scale_ub ) @@ -88,9 +86,9 @@ def ops_dynamic_per_token_quant( weight: torch.Tensor, x: torch.Tensor, quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor], -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + residual: torch.Tensor | None, + scale_ub: torch.Tensor | None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: if residual is not None: residual = residual.clone() out, scales = ops.rms_norm_dynamic_per_token_quant( @@ -103,9 +101,9 @@ def ops_impl( weight: torch.Tensor, x: torch.Tensor, quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor], -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + residual: torch.Tensor | None, + scale_ub: torch.Tensor | None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub) diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py index 7553d45e0057..49bd77f6795f 100644 --- a/tests/kernels/core/test_layernorm.py +++ b/tests/kernels/core/test_layernorm.py @@ -6,24 +6,12 @@ from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.utils import opcheck -from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [ - 8, - 768, - 769, - 770, - 771, - 5120, - 5124, - 5125, - 5126, - 8192, - 8199, -] # Arbitrary values for testing +HIDDEN_SIZES = [8, 768, 769, 5120, 5125, 8192] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @@ -82,43 +70,11 @@ def test_rms_norm( ) -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_poly_norm( - num_tokens: int, - hidden_size: int, - dtype: torch.dtype, - seed: int, - device: str, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - layer = PolyNorm().to(dtype=dtype) - layer.weight.data.normal_(mean=1.0, std=0.1) - layer.bias.data.normal_(mean=1.0, std=0.1) - scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype) - x *= scale - - ref_out = layer.forward_native(x) - out = layer(x) - torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) - - opcheck( - torch.ops._C.poly_norm, - (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon), - ) - - @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0]) +@pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0]) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) 
@pytest.mark.parametrize("strided_input", [False, True]) diff --git a/tests/kernels/core/test_permute_cols.py b/tests/kernels/core/test_permute_cols.py index 1e264735cb3c..08fdd0e055ea 100644 --- a/tests/kernels/core/test_permute_cols.py +++ b/tests/kernels/core/test_permute_cols.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("shape", [(1, 512), (544, 4096), (67, 8192)]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_permute_cols(shape, dtype): x = torch.randn(shape, dtype=dtype).cuda() perm = torch.randperm(x.shape[1]).to(torch.int).cuda() diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index 799e0a3f2a2b..c35ee5016ba0 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable from itertools import product -from typing import Callable, Optional import pytest import torch @@ -12,8 +12,8 @@ from vllm.platforms import current_platform IS_NEOX_STYLE = [True, False] -DTYPES = [torch.half, torch.bfloat16, torch.float] -HEAD_SIZES = [64, 80, 112, 120, 256] +DTYPES = [torch.bfloat16, torch.float] +HEAD_SIZES = [64, 80, 120, 256] ROTARY_DIMS = [None, 32] # None means rotary dim == head size NUM_HEADS = [17] # Arbitrary values for testing BATCH_SIZES = [5] # Arbitrary values for testing @@ -68,7 +68,7 @@ def test_rotary_embedding( seq_len: int, num_heads: int, head_size: int, - rotary_dim: Optional[int], + rotary_dim: int | None, dtype: torch.dtype, seed: int, device: str, diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index 0a292a3e2ae7..30c64e0bd72a 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -4,8 +4,6 @@ Tests for miscellaneous utilities """ -from typing import Optional - import pytest import torch @@ -17,7 +15,7 @@ def rotary_embedding_opcheck( rot, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, + key: torch.Tensor | None = None, ): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) diff --git a/tests/kernels/core/test_uva.py b/tests/kernels/core/test_uva.py index 73738175e5c7..2690346af4d3 100644 --- a/tests/kernels/core/test_uva.py +++ b/tests/kernels/core/test_uva.py @@ -3,7 +3,8 @@ import pytest import torch -from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available +from vllm.utils.platform_utils import is_uva_available +from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index fea6b94481b6..4647b97c4771 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -19,11 +18,11 @@ def causal_conv1d_ref( x: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - initial_states: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, + initial_states: torch.Tensor | None = None, return_final_states: bool = False, - final_states_out: 
Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", + final_states_out: torch.Tensor | None = None, + activation: str | None = "silu", ): """ x: (batch, dim, seqlen) @@ -117,12 +116,12 @@ def causal_conv1d_update_ref( def causal_conv1d_opcheck_fn( x: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - cu_seq_len: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", + bias: torch.Tensor | None = None, + cu_seq_len: torch.Tensor | None = None, + cache_indices: torch.Tensor | None = None, + has_initial_state: torch.Tensor | None = None, + conv_states: torch.Tensor | None = None, + activation: str | None = "silu", pad_slot_id: int = PAD_SLOT_ID, ): """ @@ -184,7 +183,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("seqlen", [1, 3]) @@ -266,7 +265,7 @@ def test_causal_conv1d_update_with_batch_gather( @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096]) +@pytest.mark.parametrize("seqlen", [8, 249, 4096]) @pytest.mark.parametrize("dim", [64, 4096]) @pytest.mark.parametrize("with_padding", [True, False]) @pytest.mark.parametrize("batch", [4, 10]) diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index d23daefa7b43..6fca33acd48a 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -13,7 +13,7 @@ ) from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated from vllm.platforms import current_platform -from vllm.utils import update_environment_variables +from vllm.utils.system_utils import update_environment_variables @multi_gpu_test(num_gpus=2) @@ -25,7 +25,6 @@ (64, 1), (64, 2), (64, 4), # hidden_size be divisible by num_gpus - (100, 5), # and n_groups must divide hidden_size ], ) @pytest.mark.parametrize("dtype", [torch.float16]) diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 9a6137239ebf..c59fc7af0c89 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -229,8 +229,8 @@ def selective_scan_opcheck_fn( @pytest.mark.parametrize("wtype", [torch.float32]) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("seqlen", [128, 1024, 4096]) @pytest.mark.parametrize("has_delta_bias", [True]) @pytest.mark.parametrize("delta_softplus", [True]) @pytest.mark.parametrize("has_z", [True]) @@ -238,7 +238,7 @@ def selective_scan_opcheck_fn( @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -@pytest.mark.parametrize("scan_chunks", [1, 2, 3]) +@pytest.mark.parametrize("scan_chunks", [1, 3]) def 
test_selective_scan( is_variable_B, is_variable_C, @@ -375,9 +375,9 @@ def test_selective_scan( ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dstate", [16, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) def test_selective_state_update(dim, dstate, has_z, itype): device = "cuda" @@ -413,7 +413,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): @pytest.mark.parametrize("wtype", [torch.float32]) @pytest.mark.parametrize("itype", [torch.float32]) -@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096]) @pytest.mark.parametrize("return_last_state", [True]) @pytest.mark.parametrize("has_delta_bias", [True]) @pytest.mark.parametrize("delta_softplus", [True]) @@ -589,9 +589,9 @@ def test_selective_scan_varlen( ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dstate", [16, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) @@ -679,11 +679,11 @@ def test_selective_state_update_with_batch_indices( assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("tie_hdim", [False, True]) -@pytest.mark.parametrize("ngroups", [1, 2, 4]) -@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("ngroups", [1, 4]) +@pytest.mark.parametrize("dstate", [16, 64]) @pytest.mark.parametrize("dim", [2048, 4096]) def test_selective_state_update_with_heads_with_batch_indices( dim, dstate, ngroups, has_z, tie_hdim, itype diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 57dcb789e97b..0b0b82e484a1 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -188,9 +188,9 @@ def end_boundary(n: int): ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) -@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) +@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [4, 16, 32]) +@pytest.mark.parametrize("d_head", [5, 8, 32, 128]) @pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)]) def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): # this tests the kernels on a single example (bs=1) @@ -254,15 +254,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it ) -@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) -@pytest.mark.parametrize("n_heads", [4, 8, 13]) -@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) +@pytest.mark.parametrize("itype", [torch.float32]) +@pytest.mark.parametrize("n_heads", [4, 8]) +@pytest.mark.parametrize("d_head", [5, 16, 32]) 
@pytest.mark.parametrize( "seq_len_chunk_size_cases", [ # small-ish chunk_size (8) (64, 8, 2, [(64, 32), (64, 32)]), - (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary ( 64, @@ -270,16 +269,7 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it 2, [(4, 4), (4, 4), (4, 4), (4, 4)], ), # chunk_size larger than cont batches - ( - 64, - 8, - 5, - [ - (64, 32, 16, 8, 8), - (8, 16, 32, 16, 8), - (8, 8, 16, 32, 16), - ], - ), # mode examples with varied lengths + (64, 8, 5, [(64, 32, 16, 8, 8)]), # large-ish chunk_size (256) (64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences ( @@ -359,11 +349,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, @pytest.mark.parametrize("chunk_size", [8, 256]) @pytest.mark.parametrize( "seqlens", - [ - (16, 2, 8, 13), - (270, 88, 212, 203), - (16, 20), - ], + [(16, 20), (270, 88, 212, 203)], ) def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): # This test verifies the correctness of the chunked prefill implementation diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 091fa4fafe21..c517e5c026b4 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any import torch @@ -23,7 +23,7 @@ FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx from .mk_objects import ( TestMoEQuantConfig, @@ -35,7 +35,7 @@ from .parallel_utils import ProcessGroupInfo -def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: +def _describe_tensor(t: torch.Tensor | None, name: str) -> str: if t is None: return f"{name} : None" else: @@ -44,21 +44,21 @@ def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: @dataclass class Config: - Ms: Union[list[int], int] + Ms: list[int] | int K: int N: int E: int - topks: Union[list[int], int] + topks: list[int] | int dtype: torch.dtype - quant_config: Optional[TestMoEQuantConfig] + quant_config: TestMoEQuantConfig | None prepare_finalize_type: mk.FusedMoEPrepareAndFinalize fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute - fused_moe_chunk_size: Optional[int] + fused_moe_chunk_size: int | None world_size: int - torch_trace_dir_path: Optional[str] = None + torch_trace_dir_path: str | None = None def __post_init__(self): if self.quant_config is None: @@ -93,7 +93,7 @@ def M(self) -> int: return self.Ms @property - def quant_dtype(self) -> Union[torch.dtype, str, None]: + def quant_dtype(self) -> torch.dtype | str | None: assert self.quant_config is not None return self.quant_config.quant_dtype @@ -112,7 +112,7 @@ def is_per_out_ch_quant(self) -> bool: return self.quant_config.per_out_ch_quant @property - def quant_block_shape(self) -> Optional[list[int]]: + def quant_block_shape(self) -> list[int] | None: assert self.quant_config is not None return self.quant_config.block_shape @@ -209,18 +209,18 @@ def all2all_backend(self): info = prepare_finalize_info(self.prepare_finalize_type) return info.backend - def is_valid(self): + def is_valid(self) -> 
tuple[bool, str | None]: # Check prepare-finalize and fused-experts compatibility if self.is_batched_prepare_finalize(): if not self.is_batched_fused_experts(): - return False + return False, "Mismatched format." else: if not self.is_standard_fused_experts(): - return False + return False, "Mismatched format." use_chunking = self.fused_moe_chunk_size is not None if use_chunking and not self.is_fe_supports_chunking(): - return False + return False, "Chunking not supported." # Check quantization sanity if ( @@ -229,7 +229,7 @@ def is_valid(self): + int(self.quant_block_shape is not None) ) > 1: # invalid quant config - return False + return False, f"Bad quant_config {self.quant_config}." # check type support if self.quant_dtype is None: @@ -237,44 +237,53 @@ def is_valid(self): self.dtype not in self.pf_supported_types() or self.dtype not in self.fe_supported_types() ): - return False + return False, ( + f"Unsupported type {self.dtype} not in " + f"{self.pf_supported_types()} and " + f"{self.fe_supported_types()}." + ) else: if ( self.quant_dtype not in self.pf_supported_types() or self.quant_dtype not in self.fe_supported_types() ): - return False + return False, ( + f"Unsupported quant type {self.quant_dtype} " + f"not in {self.pf_supported_types()} and " + f"{self.fe_supported_types()}." + ) # Check block quanization support is_block_quatized = self.quant_block_shape is not None if is_block_quatized and self.quant_dtype is None: - return False + return False, "No block quantization support." + if is_block_quatized and not self.is_block_quant_supported(): - return False + return False, "Mismatched block quantization support." # deep_gemm only works with block-quantized if self.needs_deep_gemm() and not is_block_quatized: - return False + return False, "Needs DeepGEMM but not block quantized." # Check dependencies (turn into asserts?) if self.needs_deep_ep() and not has_deep_ep(): - return False + return False, "Needs DeepEP, but DeepEP not available." if self.needs_deep_gemm() and not has_deep_gemm(): - return False + return False, "Needs DeepGEMM, but DeepGEMM not available." if self.needs_pplx() and not has_pplx(): # noqa: SIM103 - return False + return False, "Needs PPLX, but PPLX not available." 
- return True + return True, None @dataclass class WeightTensors: w1: torch.Tensor w2: torch.Tensor - w1_scale: Optional[torch.Tensor] - w2_scale: Optional[torch.Tensor] - w1_gs: Optional[torch.Tensor] = None - w2_gs: Optional[torch.Tensor] = None + w1_scale: torch.Tensor | None + w2_scale: torch.Tensor | None + w1_gs: torch.Tensor | None = None + w2_gs: torch.Tensor | None = None def describe(self): s = "" @@ -342,11 +351,11 @@ def make(config: Config) -> "WeightTensors": @dataclass class RankTensors: hidden_states: torch.Tensor - hidden_states_scale: Optional[torch.Tensor] + hidden_states_scale: torch.Tensor | None topk_weights: torch.Tensor topk_ids: torch.Tensor - expert_map: Optional[torch.Tensor] + expert_map: torch.Tensor | None def describe(self): s = "" @@ -361,7 +370,7 @@ def describe(self): @staticmethod def make_hidden_states( config: Config, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Return hidden_states """ diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py index 0ef306051c8a..95db6327c4f1 100644 --- a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -4,7 +4,6 @@ import copy from enum import Enum from itertools import product -from typing import Optional import torch from tqdm import tqdm @@ -82,7 +81,7 @@ def make_feature_matrix(csv_file_path: str): import pandas as pd def add_to_results( - config: Config, success: Result, results_df: Optional[pd.DataFrame] = None + config: Config, success: Result, results_df: pd.DataFrame | None = None ): config_dict = asdict(config) config_dict["prepare_finalize_type"] = config_dict[ @@ -121,7 +120,7 @@ def add_to_results( product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES) ) - results_df: Optional[pd.DataFrame] = None + results_df: pd.DataFrame | None = None for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm( combinations ): @@ -140,7 +139,7 @@ def add_to_results( ) success = None - if config.is_valid(): + if config.is_valid()[0]: print(f"Running config : {config.describe()} ...") try: weights: WeightTensors = WeightTensors.make(config) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 566fb1e09d3b..21eeffb1c726 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional, Union import torch @@ -36,32 +35,32 @@ cutlass_fp8_supported, ) from vllm.platforms import current_platform -from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx @dataclass class TestMoEQuantConfig: - quant_dtype: Union[torch.dtype, str, None] + quant_dtype: torch.dtype | str | None per_out_ch_quant: bool per_act_token_quant: bool - block_shape: Optional[list[int]] + block_shape: list[int] | None @dataclass class PrepareFinalizeInfo: activation_format: mk.FusedMoEActivationFormat - supported_dtypes: list[Union[torch.dtype, str]] + supported_dtypes: list[torch.dtype | str] 
blocked_quantization_support: bool - backend: Optional[str] + backend: str | None supports_apply_weight_on_input: bool = True @dataclass class ExpertInfo: activation_format: mk.FusedMoEActivationFormat - supported_dtypes: list[Union[torch.dtype, str]] + supported_dtypes: list[torch.dtype | str] blocked_quantization_support: bool supports_chunking: bool supports_expert_map: bool @@ -78,7 +77,7 @@ class ExpertInfo: standard_format = mk.FusedMoEActivationFormat.Standard batched_format = mk.FusedMoEActivationFormat.BatchedExperts -common_float_types: list[Union[torch.dtype, str]] = [ +common_float_types: list[torch.dtype | str] = [ torch.float8_e4m3fn, torch.bfloat16, torch.float16, @@ -92,9 +91,9 @@ class ExpertInfo: def register_prepare_and_finalize( kind, activation_format: mk.FusedMoEActivationFormat, - supported_dtypes: list[Union[torch.dtype, str]], + supported_dtypes: list[torch.dtype | str], blocked_quantization_support: bool, - backend: Optional[str], + backend: str | None, force_multigpu: bool = False, supports_apply_weight_on_input: bool = True, ): @@ -121,7 +120,7 @@ def register_prepare_and_finalize( def register_experts( kind, activation_format: mk.FusedMoEActivationFormat, - supported_dtypes: list[Union[torch.dtype, str]], + supported_dtypes: list[torch.dtype | str], blocked_quantization_support: bool, supports_chunking: bool, supports_expert_map: bool, @@ -244,7 +243,7 @@ def expert_info(kind) -> ExpertInfo: register_prepare_and_finalize( FlashInferCutlassMoEPrepareAndFinalize, standard_format, - nvfp4_types, + nvfp4_types + fp8_types, blocked_quantization_support=True, backend=None, force_multigpu=True, @@ -254,7 +253,7 @@ def expert_info(kind) -> ExpertInfo: register_experts( FlashInferExperts, standard_format, - nvfp4_types, + nvfp4_types + fp8_types, blocked_quantization_support=True, supports_chunking=True, # Note: this is a hack to get it to run for now @@ -274,17 +273,15 @@ def expert_info(kind) -> ExpertInfo: needs_matching_quant=False, needs_deep_gemm=True, ) - ( - register_experts( - DeepGemmExperts, - standard_format, - fp8_types, - blocked_quantization_support=True, - supports_chunking=True, - supports_expert_map=True, - needs_matching_quant=False, - needs_deep_gemm=True, - ), + register_experts( + DeepGemmExperts, + standard_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=False, + needs_deep_gemm=True, ) register_experts( BatchedTritonOrDeepGemmExperts, @@ -342,7 +339,7 @@ def expert_info(kind) -> ExpertInfo: supports_expert_map=False, ) -MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [ +MK_QUANT_CONFIGS: list[TestMoEQuantConfig | None] = [ None, # per-channel / per-column weights and per-tensor activations TestMoEQuantConfig( @@ -397,7 +394,7 @@ def expert_info(kind) -> ExpertInfo: def make_prepare_finalize( prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, - backend: Optional[str], + backend: str | None, moe: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEPrepareAndFinalize: @@ -464,7 +461,7 @@ def make_fused_experts( print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") experts = BatchedTritonOrDeepGemmExperts(**kwargs) elif fused_experts_type == DeepGemmExperts: - print("Making DeepGemmExperts {quant_config} ...") + print(f"Making DeepGemmExperts {quant_config} ...") experts = DeepGemmExperts(quant_config) elif fused_experts_type == TritonExperts: kwargs = quant_kwargs diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py 
b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py index 7802129d3d48..8528ee0cdee6 100644 --- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -3,15 +3,16 @@ import dataclasses import os import traceback -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Concatenate import torch from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import init_distributed_environment, initialize_model_parallel -from vllm.utils import get_open_port +from vllm.utils.network_utils import get_open_port ## Parallel Processes Utils @@ -58,9 +59,9 @@ def _worker_parallel_launch( world_local_size: int, node_rank: int, init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None], - vllm_config: Optional[VllmConfig], - env_dict: Optional[dict], + worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig | None, Any, P], None], + vllm_config: VllmConfig | None, + env_dict: dict | None, *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py index 48e5c4659b49..a3e264c5f5e2 100644 --- a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +from collections.abc import Callable from itertools import product -from typing import Any, Callable +from typing import Any import torch diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index fb9e5df281f1..90728c1e30a4 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -7,14 +7,16 @@ import dataclasses import os import traceback -from typing import Callable, Optional +from collections.abc import Callable +from typing import Concatenate import torch from torch.distributed import ProcessGroup from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import ParamSpec -from vllm.utils import get_open_port, has_deep_ep +from vllm.utils.import_utils import has_deep_ep +from vllm.utils.network_utils import get_open_port if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( @@ -126,8 +128,8 @@ def make_deepep_ht_a2a( pgi: ProcessGroupInfo, dp_size: int, ht_args: DeepEPHTArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None, + q_dtype: torch.dtype | None = None, + block_shape: list[int] | None = None, ): import deep_ep @@ -153,8 +155,8 @@ def make_deepep_ll_a2a( pg: ProcessGroup, pgi: ProcessGroupInfo, deepep_ll_args: DeepEPLLArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None, + q_dtype: torch.dtype | None = None, + block_shape: list[int] | None = None, ): import deep_ep @@ -185,10 +187,10 @@ def make_deepep_a2a( pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - deepep_ht_args: Optional[DeepEPHTArgs], - deepep_ll_args: Optional[DeepEPLLArgs], - q_dtype: Optional[torch.dtype] = None, - block_shape: 
Optional[list[int]] = None, + deepep_ht_args: DeepEPHTArgs | None, + deepep_ll_args: DeepEPLLArgs | None, + q_dtype: torch.dtype | None = None, + block_shape: list[int] | None = None, ): if deepep_ht_args is not None: assert deepep_ll_args is None diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 09cede3fbcc7..2dce099770f0 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import pytest import torch @@ -55,7 +54,7 @@ @dataclass class BatchedMMConfig: in_dtype: torch.dtype - quant_dtype: Optional[torch.dtype] + quant_dtype: torch.dtype | None out_dtype: torch.dtype num_experts: int max_tokens_per_expert: int @@ -115,7 +114,7 @@ def test_batched_mm( K: int, N: int, dtype: torch.dtype, - block_shape: Optional[list[int]], + block_shape: list[int] | None, per_act_token_quant: bool, ): current_platform.seed_everything(7) @@ -242,7 +241,7 @@ def test_fused_moe_batched_experts( topk: int, dtype: torch.dtype, per_act_token_quant: bool, - block_shape: Optional[list[int]], + block_shape: list[int] | None, input_scales: bool, ): current_platform.seed_everything(7) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index b8cd3cb9200c..60f9f14b7f6f 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -21,14 +21,14 @@ modular_triton_fused_moe, ) from vllm.platforms import current_platform -from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import ( + get_mk_alignment_for_contiguous_layout, + is_deep_gemm_e8m0_used, +) +from vllm.utils.import_utils import has_deep_gemm dg_available = has_deep_gemm() -if dg_available: - from deep_gemm import get_m_alignment_for_contiguous_layout - if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) @@ -218,8 +218,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch) torch.manual_seed(seed) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) - block_m = get_m_alignment_for_contiguous_layout() - block_size = [block_m, block_m] + block_size = get_mk_alignment_for_contiguous_layout() dtype = torch.bfloat16 a = torch.randn((M, K), dtype=dtype) / 10 diff --git a/tests/kernels/moe/test_count_expert_num_tokens.py b/tests/kernels/moe/test_count_expert_num_tokens.py index 996a4538d105..39138be83bcc 100644 --- a/tests/kernels/moe/test_count_expert_num_tokens.py +++ b/tests/kernels/moe/test_count_expert_num_tokens.py @@ -5,7 +5,6 @@ """ import dataclasses -from typing import Optional import pytest import torch @@ -16,7 +15,7 @@ @dataclasses.dataclass class TestTensors: topk_ids: torch.Tensor - expert_map: Optional[torch.Tensor] = None + expert_map: torch.Tensor | None = None def to_device(self, device: str): self.topk_ids = self.topk_ids.to(device=device) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index b82cea61bd4e..4330eda251f7 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -3,7 +3,6 @@ import copy import dataclasses from math import prod -from typing import Optional import pytest import torch @@ -85,16 +84,16 @@ def make_moe_tensors( @dataclasses.dataclass class 
MOETensors8Bit(MOETensors): # quantized - a_q: Optional[torch.Tensor] = None # a -> a_q - w1_q: Optional[torch.Tensor] = None # w1 -> w1_q - w2_q: Optional[torch.Tensor] = None # w2 -> w2_q - a_scale: Optional[torch.Tensor] = None - w1_scale: Optional[torch.Tensor] = None - w2_scale: Optional[torch.Tensor] = None + a_q: torch.Tensor | None = None # a -> a_q + w1_q: torch.Tensor | None = None # w1 -> w1_q + w2_q: torch.Tensor | None = None # w2 -> w2_q + a_scale: torch.Tensor | None = None + w1_scale: torch.Tensor | None = None + w2_scale: torch.Tensor | None = None # dequantized - a_d: Optional[torch.Tensor] = None # a -> a_q -> a_d - w1_d: Optional[torch.Tensor] = None # w1 -> w1_q -> w1_d - w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d + a_d: torch.Tensor | None = None # a -> a_q -> a_d + w1_d: torch.Tensor | None = None # w1 -> w1_q -> w1_d + w2_d: torch.Tensor | None = None # w2 -> w2_q -> w2_d @staticmethod def make_moe_tensors_8bit( @@ -209,7 +208,7 @@ def run_8_bit( topk_ids: torch.Tensor, per_act_token: bool, per_out_ch: bool, - num_local_experts: Optional[int] = None, + num_local_experts: int | None = None, ) -> torch.Tensor: assert not any( [ @@ -280,7 +279,7 @@ def test_cutlass_moe_8_bit_no_graph( per_act_token: bool, per_out_ch: bool, monkeypatch, - ep_size: Optional[int] = None, + ep_size: int | None = None, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index e68c5bfa5946..d46f453488a9 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -7,7 +7,6 @@ """ import dataclasses -from typing import Optional import pytest import torch.distributed @@ -22,8 +21,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform -from vllm.utils import has_deep_ep, has_deep_gemm from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported +from vllm.utils.import_utils import has_deep_ep, has_deep_gemm from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -92,13 +91,13 @@ class TestConfig: block_size: list[int] # configs for testing low-latency kernels low_latency: bool - use_fp8_dispatch: Optional[bool] = False + use_fp8_dispatch: bool | None = False @dataclasses.dataclass class TestTensors: rank_tokens: torch.Tensor # all ranks make this many tokens - rank_token_scales: Optional[torch.Tensor] + rank_token_scales: torch.Tensor | None topk: torch.Tensor topk_weights: torch.Tensor config: TestConfig @@ -143,7 +142,7 @@ def make_ll_modular_kernel( max_tokens_per_rank: int, dp_size: int, hidden_size: int, - q_dtype: Optional[torch.dtype], + q_dtype: torch.dtype | None, test_config: TestConfig, quant_config: FusedMoEQuantConfig, ) -> FusedMoEModularKernel: @@ -179,7 +178,7 @@ def make_ht_modular_kernel( pgi: ProcessGroupInfo, dp_size: int, num_local_experts: int, - q_dtype: Optional[torch.dtype], + q_dtype: torch.dtype | None, test_config: TestConfig, quant_config: FusedMoEQuantConfig, ) -> FusedMoEModularKernel: @@ -249,8 +248,8 @@ def deepep_deepgemm_moe_impl( test_tensors: TestTensors, w1: torch.Tensor, w2: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], + w1_scale: torch.Tensor | None, + w2_scale: torch.Tensor | None, ) -> 
torch.Tensor: test_config = test_tensors.config num_experts = test_config.num_experts diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index a1dabea1f0c7..b49319a7e6f5 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -5,7 +5,6 @@ """ import dataclasses -from typing import Optional, Union import pytest import torch.distributed @@ -22,7 +21,7 @@ per_token_group_quant_fp8, ) from vllm.platforms import current_platform -from vllm.utils import has_deep_ep +from vllm.utils.import_utils import has_deep_ep from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -90,7 +89,7 @@ class TestConfig: @dataclasses.dataclass class TestTensors: rank_tokens: torch.Tensor # all ranks make this many tokens - rank_token_scales: Optional[torch.Tensor] + rank_token_scales: torch.Tensor | None topk: torch.Tensor topk_weights: torch.Tensor config: TestConfig @@ -128,12 +127,12 @@ def make_modular_kernel( dp_size: int, num_experts: int, num_local_experts: int, - q_dtype: Optional[torch.dtype], + q_dtype: torch.dtype | None, use_fp8_dispatch: bool, quant_config: FusedMoEQuantConfig, ) -> FusedMoEModularKernel: - ht_args: Optional[DeepEPHTArgs] = None - ll_args: Optional[DeepEPLLArgs] = None + ht_args: DeepEPHTArgs | None = None + ll_args: DeepEPLLArgs | None = None if low_latency_mode: ll_args = DeepEPLLArgs( @@ -148,16 +147,14 @@ def make_modular_kernel( ) ht_args = DeepEPHTArgs(num_local_experts=num_local_experts) - a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = ( - make_deepep_a2a( - pg=pg, - pgi=pgi, - dp_size=dp_size, - q_dtype=q_dtype, - block_shape=None, - deepep_ht_args=ht_args, - deepep_ll_args=ll_args, - ) + a2a: DeepEPHTPrepareAndFinalize | DeepEPLLPrepareAndFinalize = make_deepep_a2a( + pg=pg, + pgi=pgi, + dp_size=dp_size, + q_dtype=q_dtype, + block_shape=None, + deepep_ht_args=ht_args, + deepep_ll_args=ll_args, ) num_dispatchers = pgi.world_size // dp_size @@ -184,8 +181,8 @@ def deep_ep_moe_impl( test_tensors: TestTensors, w1: torch.Tensor, w2: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], + w1_scale: torch.Tensor | None, + w2_scale: torch.Tensor | None, num_experts: int, use_fp8_dispatch: bool, per_act_token_quant: bool, @@ -281,8 +278,8 @@ def torch_moe_impl( test_tensors: TestTensors, w1: torch.Tensor, w2: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], + w1_scale: torch.Tensor | None, + w2_scale: torch.Tensor | None, using_fp8_dispatch: bool, per_act_token_quant: bool, ): @@ -340,8 +337,8 @@ def _deep_ep_moe( config: TestConfig, w1: torch.Tensor, w2: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], + w1_scale: torch.Tensor | None, + w2_scale: torch.Tensor | None, use_fp8_dispatch: bool, per_act_token_quant: bool, ): diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index f78596d220bf..d4a79a7eff75 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F -from vllm.utils import has_triton_kernels +from vllm.utils.import_utils import has_triton_kernels if not has_triton_kernels(): pytest.skip( @@ -23,15 +23,9 @@ from triton_kernels.testing import assert_close from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from 
vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize, -) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - BatchedOAITritonExperts, triton_kernel_moe_forward, ) -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.utils import shuffle_weight from vllm.utils import round_up @@ -302,8 +296,8 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): quant_config = FusedMoEQuantConfig.make( w1_bias=w1_bias_tri, w2_bias=w2_bias_tri, - w1_precision=pc1, - w2_precision=pc2, + w1_scale=pc1, + w2_scale=pc2, ) out_triton_monolithic = triton_kernel_moe_forward( @@ -329,115 +323,6 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005) -def batched_moe( - a: torch.Tensor, - w1, - w2, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - w1_bias: torch.Tensor, - w2_bias: torch.Tensor, - w1_precision: PrecisionConfig, - w2_precision: PrecisionConfig, -) -> torch.Tensor: - max_num_tokens = round_up(a.shape[0], 64) - - quant_config = FusedMoEQuantConfig.make( - w1_precision=w1_precision, - w2_precision=w2_precision, - w1_bias=w1_bias, - w2_bias=w2_bias, - ) - - fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize( - max_num_tokens, - num_dispatchers=1, - num_local_experts=w1.shape[0], - rank=0, - ), - BatchedOAITritonExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=1, - quant_config=quant_config, - ), - ) - - topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize) - - return fused_experts( - a, - w1, - w2, - topk_weight, - topk_ids, - ) - - -@pytest.mark.parametrize( - ", ".join(f.name for f in fields(Case)), - [ - tuple(getattr(case, f.name) for f in fields(Case)) - for case in [ - # Case(a_dtype="bf16", w_dtype="bf16"), - # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), - Case(a_dtype="bf16", w_dtype="mx4") - ] - ], -) -@pytest.mark.parametrize("num_token", [64]) -@pytest.mark.parametrize("ep", [1, 2, 4, 8]) -def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep): - M = num_token - E = ModelConfig.num_experts // ep - K = ModelConfig.hidden_size - N = ModelConfig.intermediate_size - topk = ModelConfig.experts_per_token - - ( - x, - w1, - w1_bias, - w2, - w2_bias, - exp_data, - x_tri, - w1_tri, - w2_tri, - exp_data_tri, - w1_bias_tri, - w2_bias_tri, - pc1, - pc2, - ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=4) - - out_tri = batched_moe( - a=x_tri, - w1=w1_tri, - w2=w2_tri, - gating_output=exp_data_tri, - topk=topk, - renormalize=True, - w1_bias=w1_bias_tri, - w2_bias=w2_bias_tri, - w1_precision=pc1, - w2_precision=pc2, - ) - out_tri = out_tri[..., :K] - - out_ref = oai_moe_forward( - hidden_states=x, - w1=w1, - w1_bias=w1_bias, - w2=w2, - w2_bias=w2_bias, - gating_output=exp_data, - topk=topk, - ) - assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005) - - def test_unit_shuffle(): N = ModelConfig.intermediate_size K = ModelConfig.hidden_size diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 9c4114523590..a46b0053e75a 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -5,7 +5,7 @@ import textwrap import traceback from itertools import product -from typing import Optional +from 
typing import Any import pytest import torch @@ -13,10 +13,10 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, set_current_vllm_config from vllm.platforms import current_platform -from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils.torch_utils import cuda_device_count_stateless -from ...utils import multi_gpu_test from .modular_kernel_tools.common import ( Config, RankTensors, @@ -132,7 +132,8 @@ def rank_worker( def run(config: Config, verbose: bool): - assert config.is_valid() + assert config.is_valid()[0] + assert not is_nyi_config(config) weights: WeightTensors = WeightTensors.make(config) @@ -168,31 +169,97 @@ def is_nyi_config(config: Config) -> bool: return not info.supports_expert_map -@pytest.mark.parametrize("k", Ks) -@pytest.mark.parametrize("n", Ns) -@pytest.mark.parametrize("e", Es) -@pytest.mark.parametrize("dtype", DTYPEs) -@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) +def generate_valid_test_cases( + world_size: int, prepare_finalize_types +) -> list[tuple[Any, ...]]: + cases = [] + total = 0 + + for k, n, e, dtype, quant_config, combination, chunk_size in product( + Ks, + Ns, + Es, + DTYPEs, + MK_QUANT_CONFIGS, + product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES), + FUSED_MOE_CHUNK_SIZEs, + ): + total = total + 1 + + config = Config( + Ms=Ms, + K=k, + N=n, + E=e, + topks=TOPKs, + dtype=dtype, + quant_config=quant_config, + prepare_finalize_type=combination[0], + fused_experts_type=combination[1], + fused_moe_chunk_size=chunk_size, + world_size=world_size, + ) + + # TODO(bnell): figure out how to get verbose flag here. + verbose = False # pytestconfig.getoption('verbose') > 0 + + valid, reason = config.is_valid() + + if not valid: + if verbose: + print(f"Test config {config} is not valid: {reason}") + continue + + if is_nyi_config(config): + if verbose: + print(f"Test config {config} is nyi.") + continue + + cases.append( + ( + k, + n, + e, + dtype, + quant_config, + combination[0], + combination[1], + chunk_size, + world_size, + ) + ) + + print(f"{len(cases)} of {total} valid configs generated.") + + return cases + + @pytest.mark.parametrize( - "combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) + "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size", + generate_valid_test_cases( + world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + ), ) -@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) -@pytest.mark.parametrize("world_size", [2]) -@multi_gpu_test(num_gpus=2) @meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[TestMoEQuantConfig], - combination: tuple[ - mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute - ], - fused_moe_chunk_size: Optional[int], + quant_config: TestMoEQuantConfig | None, + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + chunk_size: int | None, world_size: int, pytestconfig, ): + if cuda_device_count_stateless() < world_size: + pytest.skip( + f"Not enough GPUs available to run, got " + f"{cuda_device_count_stateless()} exepected " + f"{world_size}." 
+ ) + config = Config( Ms=Ms, K=k, @@ -201,42 +268,30 @@ def test_modular_kernel_combinations_multigpu( topks=TOPKs, dtype=dtype, quant_config=quant_config, - prepare_finalize_type=combination[0], - fused_experts_type=combination[1], - fused_moe_chunk_size=fused_moe_chunk_size, + prepare_finalize_type=prepare_finalize_type, + fused_experts_type=fused_experts_type, + fused_moe_chunk_size=chunk_size, world_size=world_size, ) - - if not config.is_valid(): - pytest.skip(f"Tests config {config} is not valid. Skipping ...") - - if is_nyi_config(config): - pytest.skip(f"Tests config {config} is nyi. Skipping ...") - verbosity = pytestconfig.getoption("verbose") run(config, verbosity > 0) -@pytest.mark.parametrize("k", Ks) -@pytest.mark.parametrize("n", Ns) -@pytest.mark.parametrize("e", Es) -@pytest.mark.parametrize("dtype", DTYPEs) -@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) @pytest.mark.parametrize( - "combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) + "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size", + generate_valid_test_cases( + world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES + ), ) -@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) -@pytest.mark.parametrize("world_size", [1]) def test_modular_kernel_combinations_singlegpu( k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[TestMoEQuantConfig], - combination: tuple[ - mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute - ], - fused_moe_chunk_size: Optional[int], + quant_config: TestMoEQuantConfig | None, + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + chunk_size: int | None, world_size: int, pytestconfig, ): @@ -248,18 +303,12 @@ def test_modular_kernel_combinations_singlegpu( topks=TOPKs, dtype=dtype, quant_config=quant_config, - prepare_finalize_type=combination[0], - fused_experts_type=combination[1], - fused_moe_chunk_size=fused_moe_chunk_size, + prepare_finalize_type=prepare_finalize_type, + fused_experts_type=fused_experts_type, + fused_moe_chunk_size=chunk_size, world_size=world_size, ) - if not config.is_valid(): - pytest.skip(f"Tests config {config} is not valid. Skipping ...") - - if is_nyi_config(config): - pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") - verbosity = pytestconfig.getoption("verbose") run(config, verbosity > 0) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index f357d149bd07..2c802ff4e6bd 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -6,7 +6,9 @@ """ import functools -from typing import Callable, Optional, Union +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any import pytest import torch @@ -26,6 +28,10 @@ int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + batched_fused_marlin_moe, + fused_marlin_moe, +) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe, @@ -80,7 +86,7 @@ def run_moe_test( - baseline: Union[Callable, torch.Tensor], + baseline: Callable | torch.Tensor, moe_fn: Callable, a: torch.Tensor, w1: torch.Tensor, @@ -88,7 +94,7 @@ def run_moe_test( score: torch.Tensor, topk: int, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, padding: bool = False, use_compile: bool = False, use_cudagraph: bool = False, @@ -212,7 +218,7 @@ def m_fused_moe( score: torch.Tensor, topk: int, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) return m_fused_moe_fn( @@ -563,6 +569,105 @@ def is_invalid( return cases +@dataclass +class MarlinMoEWeightData: + w_ref: torch.Tensor + qweight: torch.Tensor + scales: torch.Tensor + global_scale: torch.Tensor | None + g_idx: torch.Tensor | None + zeros: torch.Tensor | None + sort_indices: torch.Tensor | None + marlin_bias: torch.Tensor | None + + @staticmethod + def make( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool | None = None, + bias: torch.Tensor | None = None, + ) -> "MarlinMoEWeightData": + assert w.ndim == 3 + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + k = w.shape[-1] + + w_ref_l: list[torch.Tensor] = [] + qweight_l: list[torch.Tensor] = [] + scales_l: list[torch.Tensor] = [] + global_scale_l: list[torch.Tensor] = [] + zeros_l: list[torch.Tensor] = [] + g_idx_l: list[torch.Tensor] = [] + sort_indices_l: list[torch.Tensor] = [] + bias_l: list[torch.Tensor] = [] + + for i in range(w.shape[0]): + if quant_type == scalar_types.float4_e2m1f: + if group_size == 16: + w_ref, qweight, scales, global_scale = ( + rand_marlin_weight_nvfp4_like(w[i], group_size) + ) + else: + w_ref, qweight, scales = rand_marlin_weight_mxfp4_like( + w[i], group_size + ) + global_scale = None + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + if global_scale is not None: + global_scale_l.append(global_scale) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size) + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + elif has_zp: + w_ref, qweight, scales, zeros = awq_marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size + ) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + scales_l.append(scales) + zeros_l.append(zeros) + else: + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) + + w_ref_l.append(w_ref.T) + qweight_l.append(qweight) + 
scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) + + if bias is not None: + bias_l.append(marlin_permute_bias(bias[i])) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweight_l).contiguous() + scales = stack_and_dev(scales_l) + global_scale = stack_and_dev(global_scale_l) if global_scale_l else None + g_idx = stack_and_dev(g_idx_l) if g_idx_l else None + zeros = stack_and_dev(zeros_l) if zeros_l else None + sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None + marlin_bias = stack_and_dev(bias_l) if bias_l else None + + return MarlinMoEWeightData( + w_ref=w_ref, + qweight=qweight, + scales=scales, + global_scale=global_scale, + g_idx=g_idx, + zeros=zeros, + sort_indices=sort_indices, + marlin_bias=marlin_bias, + ) + + @pytest.mark.flaky(reruns=2) @pytest.mark.parametrize( ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), @@ -583,7 +688,6 @@ def test_fused_marlin_moe( is_k_full: bool, ): torch.cuda.manual_seed(0) - has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 @@ -599,152 +703,44 @@ def test_fused_marlin_moe( else: e_map = None - w_ref1_l = [] - qweight1_l = [] - scales1_l = [] - global_scale1_l = [] - zeros1_l = [] - g_idx1_l = [] - sort_indices1_l = [] - - for i in range(w1.shape[0]): - if quant_type == scalar_types.float4_e2m1f: - if group_size == 16: - w_ref1, qweight1, scales1, global_scale1 = ( - rand_marlin_weight_nvfp4_like(w1[i], group_size) - ) - else: - w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like( - w1[i], group_size - ) - global_scale1 = None - - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - if global_scale1 is not None: - global_scale1_l.append(global_scale1) - elif quant_type == scalar_types.float8_e4m3fn: - w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size) - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - elif has_zp: - w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size - ) - - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - zeros1_l.append(zeros1) - else: - test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) - - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - g_idx1_l.append(g_idx1) - sort_indices1_l.append(sort_indices1) - - w_ref1 = stack_and_dev(w_ref1_l) - qweight1 = stack_and_dev(qweight1_l).contiguous() - scales1 = stack_and_dev(scales1_l) - global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None - g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None - zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None - sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None - - w_ref2_l = [] - qweight2_l = [] - scales2_l = [] - global_scale2_l = [] - zeros2_l = [] - g_idx2_l = [] - sort_indices2_l = [] - - for i in range(w2.shape[0]): - if quant_type == scalar_types.float4_e2m1f: - if group_size == 16: - w_ref2, qweight2, scales2, global_scale2 = ( - rand_marlin_weight_nvfp4_like(w2[i], group_size) - ) - else: - w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like( - w2[i], group_size - ) - global_scale2 = None - - 
w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - if global_scale2 is not None: - global_scale2_l.append(global_scale2) - elif quant_type == scalar_types.float8_e4m3fn: - w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size) - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - elif has_zp: - w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size - ) - - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - zeros2_l.append(zeros2) - else: - test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) - - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - g_idx2_l.append(g_idx2) - sort_indices2_l.append(sort_indices2) + w1_data = MarlinMoEWeightData.make( + w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order + ) - w_ref2 = stack_and_dev(w_ref2_l) - qweight2 = stack_and_dev(qweight2_l).contiguous() - scales2 = stack_and_dev(scales2_l) - global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None - g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None - zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None - sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None + w2_data = MarlinMoEWeightData.make( + w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map) + torch_output = torch_moe( + a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map + ) - marlin_output = torch.ops.vllm.fused_marlin_moe( + marlin_output = fused_marlin_moe( a, - qweight1, - qweight2, + w1_data.qweight, + w2_data.qweight, None, None, - scales1, - scales2, + w1_data.scales, + w2_data.scales, score, topk_weights, topk_ids, global_num_experts=e, expert_map=e_map, - global_scale1=global_scale1, - global_scale2=global_scale2, - g_idx1=g_idx1, - g_idx2=g_idx2, - sort_indices1=sort_indices1, - sort_indices2=sort_indices2, - w1_zeros=zeros1, - w2_zeros=zeros2, + global_scale1=w1_data.global_scale, + global_scale2=w2_data.global_scale, + g_idx1=w1_data.g_idx, + g_idx2=w2_data.g_idx, + sort_indices1=w1_data.sort_indices, + sort_indices2=w2_data.sort_indices, + w1_zeros=w1_data.zeros, + w2_zeros=w2_data.zeros, quant_type_id=quant_type.id, is_k_full=is_k_full, ) @@ -772,92 +768,52 @@ def test_fused_marlin_moe_with_bias(m): b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10 b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10 - b_bias1_l = [] - w_ref1_l = [] - qweight1_l = [] - scales1_l = [] - g_idx1_l = [] - sort_indices1_l = [] - - for i in range(w1.shape[0]): - test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) - - w_ref1_l.append(w_ref1.T) - qweight1_l.append(qweight1) - scales1_l.append(scales1) - g_idx1_l.append(g_idx1) - sort_indices1_l.append(sort_indices1) - b_bias1_l.append(marlin_permute_bias(b_bias1[i])) - - w_ref1 = stack_and_dev(w_ref1_l) - qweight1 = stack_and_dev(qweight1_l).contiguous() - scales1 = stack_and_dev(scales1_l) - 
global_scale1 = None - g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None - zeros1 = None - sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None - marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None - - b_bias2_l = [] - w_ref2_l = [] - qweight2_l = [] - scales2_l = [] - g_idx2_l = [] - sort_indices2_l = [] - - for i in range(w2.shape[0]): - test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm - ) + w1_data = MarlinMoEWeightData.make( + w=w1, + quant_type=quant_type, + group_size=group_size, + act_order=act_order, + bias=b_bias1, + ) - w_ref2_l.append(w_ref2.T) - qweight2_l.append(qweight2) - scales2_l.append(scales2) - g_idx2_l.append(g_idx2) - sort_indices2_l.append(sort_indices2) - b_bias2_l.append(marlin_permute_bias(b_bias2[i])) - - w_ref2 = stack_and_dev(w_ref2_l) - qweight2 = stack_and_dev(qweight2_l).contiguous() - scales2 = stack_and_dev(scales2_l) - global_scale2 = None - g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None - zeros2 = None - sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None - marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None + w2_data = MarlinMoEWeightData.make( + w=w2, + quant_type=quant_type, + group_size=group_size, + act_order=act_order, + bias=b_bias2, + ) score = torch.randn((m, e), device="cuda", dtype=dtype) topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2) + torch_output = torch_moe( + a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2 + ) - marlin_output = torch.ops.vllm.fused_marlin_moe( + marlin_output = fused_marlin_moe( a, - qweight1, - qweight2, - marlin_bias1, - marlin_bias2, - scales1, - scales2, + w1_data.qweight, + w2_data.qweight, + w1_data.marlin_bias, + w2_data.marlin_bias, + w1_data.scales, + w2_data.scales, score, topk_weights, topk_ids, global_num_experts=e, expert_map=None, - global_scale1=global_scale1, - global_scale2=global_scale2, - g_idx1=g_idx1, - g_idx2=g_idx2, - sort_indices1=sort_indices1, - sort_indices2=sort_indices2, - w1_zeros=zeros1, - w2_zeros=zeros2, + global_scale1=w1_data.global_scale, + global_scale2=w2_data.global_scale, + g_idx1=w1_data.g_idx, + g_idx2=w2_data.g_idx, + sort_indices1=w1_data.sort_indices, + sort_indices2=w2_data.sort_indices, + w1_zeros=w1_data.zeros, + w2_zeros=w2_data.zeros, quant_type_id=quant_type.id, is_k_full=is_k_full, ) @@ -894,6 +850,41 @@ def test_moe_align_block_size_opcheck(): ) +def test_batched_moe_align_block_size_opcheck(): + max_tokens_per_batch = 512 + num_experts = 4 + block_size = 16 + + expert_num_tokens = torch.randint( + low=0, + high=max_tokens_per_batch, + size=(num_experts,), + dtype=torch.int32, + device="cuda", + ) + + max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size) + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + + assert max_num_tokens_padded % block_size == 0 + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") + + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + + opcheck( + torch.ops._moe_C.batched_moe_align_block_size, + ( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ), + ) + + 
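As a quick orientation for the opcheck test above: the buffer sizes it allocates follow directly from the number of experts, the per-expert batch capacity, and the block size. The sketch below mirrors that arithmetic; it is illustrative only, and `batched_align_buffer_sizes` is a hypothetical helper name, not part of this change.

def batched_align_buffer_sizes(
    num_experts: int, max_tokens_per_batch: int, block_size: int
) -> tuple[int, int]:
    # Each expert's padded region occupies at least one full block, so the
    # flattened sorted_ids buffer needs num_experts * max(capacity, block_size)
    # slots in the worst case.
    max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size)
    assert max_num_tokens_padded % block_size == 0
    # expert_ids holds one entry per block of sorted_ids.
    return max_num_tokens_padded, max_num_tokens_padded // block_size

# With the test's parameters (4 experts, 512 tokens per batch, block size 16)
# this gives (2048, 128), matching the sorted_ids / expert_ids shapes allocated above.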
@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -978,3 +969,171 @@ def __init__(self, w13, w2, b1=None, b2=None): else: atol = 5e-2 torch.testing.assert_close(out, ref, atol=atol, rtol=0) + + +@pytest.mark.parametrize("m", [16, 32, 64]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8, 12, 16, 32]) +@pytest.mark.parametrize("topk", [2, 4]) +@pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64]) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +def test_batched_fused_marlin_moe( + m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int +): + print( + f"testing m={m}, n={n}, k={k}, e={e}, " + f"topk={topk}, " + f"max_tokens_per_batch={max_tokens_per_batch}" + ) + torch.cuda.manual_seed(0) + + dtype = torch.bfloat16 + quant_dtype = scalar_types.float4_e2m1f + group_size = 32 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 + + w1_data = MarlinMoEWeightData.make( + w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None + ) + w2_data = MarlinMoEWeightData.make( + w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None + ) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + + class BatchedRun: + @staticmethod + def _make_expert_num_tokens_cpu( + e: int, # num_experts + topk_ids_cpu: torch.Tensor, + ) -> torch.Tensor: + expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu") + for topk_id in torch.flatten(topk_ids_cpu): + expert_num_tokens_cpu[topk_id] += 1 + return expert_num_tokens_cpu + + def __init__( + self, + max_tokens_per_batch: int, + num_experts: int, + _topk_ids: torch.Tensor, + _topk_weights: torch.Tensor, + ): + self.max_tokens_per_batch = max_tokens_per_batch + self.e = num_experts + self.topk_ids_cpu = _topk_ids.to("cpu") + self.topk_weights_cpu = _topk_weights.to("cpu") + self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu( + self.e, self.topk_ids_cpu + ) + + def is_valid(self): + """ + Return True only if the input can be represented in a Batched + format. 
+ """ + return torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch) + + def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states_cpu = hidden_states.to("cpu") + K = hidden_states_cpu.size(1) + batched_hidden_states_cpu = torch.empty( + (e, max_tokens_per_batch, K), + dtype=hidden_states_cpu.dtype, + device="cpu", + ) + + counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu) + for t_idx, token in enumerate(hidden_states_cpu): + for topk_id in self.topk_ids_cpu[t_idx]: + pos_in_batch = counter_cpu[topk_id] + batched_hidden_states_cpu[topk_id, pos_in_batch] = token + counter_cpu[topk_id] += 1 + assert torch.allclose(counter_cpu, self.expert_num_tokens_cpu) + return batched_hidden_states_cpu.to("cuda") + + def _gather( + self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor + ) -> torch.Tensor: + batched_outputs_cpu = batched_outputs.to("cpu") + gather_outputs_cpu = torch.zeros_like(gather_outputs) + + counter_cpu = torch.zeros((e,), device="cpu", dtype=torch.int32) + md = gather_outputs_cpu.size(0) + for t_idx in range(md): + token = None + for topk_id, topk_weight in zip( + self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx] + ): + pos_in_batch = counter_cpu[topk_id] + t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight + if token is None: + token = t + else: + token += t + counter_cpu[topk_id] += 1 + assert token is not None + gather_outputs_cpu[t_idx] = token + gather_outputs.copy_(gather_outputs_cpu) + return gather_outputs + + def run( + self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any] + ) -> torch.Tensor: + assert hidden_states.ndim == 2 + assert self.is_valid() + + batched_hidden_states = self._scatter(hidden_states) + + kwargs = fused_marlin_moe_kwargs | { + "hidden_states": batched_hidden_states, + "expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"), + } + batched_outputs = batched_fused_marlin_moe(**kwargs) + + output = torch.zeros_like(hidden_states) + output = self._gather(batched_outputs, output) + return output + + kwargs = { + "w1": w1_data.qweight, + "w2": w2_data.qweight, + "bias1": None, + "bias2": None, + "w1_scale": w1_data.scales, + "w2_scale": w2_data.scales, + "gating_output": score, + "global_num_experts": e, + "expert_map": None, + "global_scale1": w1_data.global_scale, + "global_scale2": w2_data.global_scale, + "g_idx1": w1_data.g_idx, + "g_idx2": w2_data.g_idx, + "sort_indices1": w1_data.sort_indices, + "sort_indices2": w2_data.sort_indices, + "w1_zeros": w1_data.zeros, + "w2_zeros": w2_data.zeros, + "quant_type_id": quant_dtype.id, + "is_k_full": True, + } + + # Reference + fused_marlin_moe_kwargs = kwargs | { + "hidden_states": a, + "topk_ids": topk_ids, + "topk_weights": topk_weights, + } + ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs) + + # Batched + br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights) + if not br.is_valid(): + pytest.skip("Cannot represent data in Batched Format.") + marlin_output = br.run(a, kwargs) + + torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0) diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index f92526e74955..bde0478d9c18 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -5,12 +5,11 @@ Run `pytest tests/kernels/moe/test_moe_align_block_size.py`. 
""" -from typing import Optional - import pytest import torch from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + batched_moe_align_block_size, moe_align_block_size, ) from vllm.platforms import current_platform @@ -94,7 +93,7 @@ def torch_moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, pad_sorted_ids: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -302,3 +301,96 @@ def test_moe_align_block_size_deterministic(): assert torch.equal(results[0][2], results[i][2]), ( "num_tokens should be deterministic" ) + + +@pytest.mark.parametrize("max_tokens_per_batch", [13, 16, 512]) +@pytest.mark.parametrize("num_experts", [8, 16, 32, 64]) +@pytest.mark.parametrize("block_size", [8, 16, 32, 64]) +@pytest.mark.parametrize("simulate_empty_batches", [False, True]) +def test_batched_moe_align_block_size( + max_tokens_per_batch: int, + num_experts: int, + block_size: int, + simulate_empty_batches: bool, +): + def ref_outputs( + expert_num_tokens: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + E = expert_num_tokens.size(0) + + # Round up so each batch can be split to blocks evenly. + Msum = round_up(max_tokens_per_batch, block_size) * E + ref_sorted_ids = torch.empty((Msum,), dtype=torch.int32) + ref_expert_ids = torch.empty((Msum // block_size,), dtype=torch.int32) + ref_num_tokens_post_pad = torch.empty((1,), dtype=torch.int32) + + # Intialize + sentinel = E * max_tokens_per_batch + ref_sorted_ids.fill_(sentinel) + ref_expert_ids.fill_(-1) + + # Fill ref_sorted_ids + i = 0 + for expert_id, expert_nt in enumerate(expert_num_tokens): + token_offset = expert_id * max_tokens_per_batch + for j in range(expert_nt): + ref_sorted_ids[i] = token_offset + j + i += 1 + # round up i to the next block_size + i = round_up(i, block_size) + + ref_num_tokens_post_pad[0] = i + + # Fill expert_ids + nt_ceil_sum = 0 + for expert_id, expert_nt in enumerate(expert_num_tokens): + expert_ids_offset = nt_ceil_sum // block_size + ceil_expert_nt = round_up(int(expert_nt.item()), block_size) + num_blocks = ceil_expert_nt // block_size + for x in range(num_blocks): + ref_expert_ids[expert_ids_offset + x] = expert_id + nt_ceil_sum += ceil_expert_nt + + return ( + ref_sorted_ids.to("cuda"), + ref_expert_ids.to("cuda"), + ref_num_tokens_post_pad.to("cuda"), + ) + + # Compute expert_num_tokens + expert_num_tokens = torch.randint( + low=0, + high=max_tokens_per_batch, + size=(num_experts,), + device="cpu", + dtype=torch.int32, + ) + if simulate_empty_batches: + # mark half the batches to have 0 tokens + zero_batches = torch.randperm(num_experts)[: num_experts // 2] + expert_num_tokens[zero_batches] = 0 + + # ref outputs + ref_sorted_ids, ref_expert_ids, ref_num_tokens_post_pad = ref_outputs( + expert_num_tokens + ) + + # outputs + sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size( + max_tokens_per_batch, block_size, expert_num_tokens.to("cuda") + ) + + assert ref_sorted_ids.size() == sorted_ids.size(), ( + f"{ref_sorted_ids.size()} vs {sorted_ids.size()}" + ) + assert ref_expert_ids.size() == expert_ids.size(), ( + f"{ref_expert_ids.size()} vs {expert_ids.size()}" + ) + assert ref_num_tokens_post_pad.size() == num_tokens_post_pad.size(), ( + f"{ref_num_tokens_post_pad.size()} vs {num_tokens_post_pad.size()}" + ) + torch.testing.assert_close(ref_sorted_ids, sorted_ids, atol=0, rtol=0) + 
torch.testing.assert_close(ref_expert_ids, expert_ids, atol=0, rtol=0) + torch.testing.assert_close( + ref_num_tokens_post_pad, num_tokens_post_pad, atol=0, rtol=0 + ) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index a6214437d404..ba1f657b3ecd 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -5,8 +5,6 @@ Run `pytest tests/kernels/test_moe_permute_unpermute.py`. """ -from typing import Optional - import numpy as np import pytest import torch @@ -34,8 +32,8 @@ def torch_permute( n_expert: int, n_local_expert: int, start_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, + expert_map: torch.Tensor | None = None, + align_block_size: int | None = None, fill_invalid_expert: int = -1, ) -> list[torch.Tensor]: n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] @@ -210,7 +208,7 @@ def test_moe_permute_unpermute( n_expert: int, ep_size: int, dtype: torch.dtype, - align_block_size: Optional[int], + align_block_size: int | None, ): if not moe_permute_unpermute_supported(): pytest.skip("moe_permute_unpermute is not supported on this platform.") @@ -219,7 +217,7 @@ def test_moe_permute_unpermute( expert_map = None n_local_expert = n_expert if ep_size != 1: - n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert) + n_local_expert, expert_map, _ = determine_expert_map(ep_size, ep_rank, n_expert) expert_map = expert_map.cuda() start_expert = n_local_expert * ep_rank current_platform.seed_everything(0) diff --git a/tests/kernels/moe/test_ocp_mx_moe.py b/tests/kernels/moe/test_ocp_mx_moe.py index dceed34f3512..91b508d4163c 100644 --- a/tests/kernels/moe/test_ocp_mx_moe.py +++ b/tests/kernels/moe/test_ocp_mx_moe.py @@ -4,7 +4,6 @@ import importlib.metadata from dataclasses import dataclass from importlib.util import find_spec -from typing import Optional import pytest import torch @@ -38,7 +37,7 @@ trtllm_fp4_block_scale_moe, ) from flashinfer.fp4_quantization import nvfp4_block_scale_interleave - from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices + from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache @dataclass @@ -103,7 +102,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): assert output -def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None): +def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: float | None = None): # Note we add an extra bias of 1 to the linear layer x_glu, x_linear = torch.chunk(x, 2, dim=-1) if limit is not None: @@ -320,7 +319,7 @@ def tg_mxfp4_moe( if transpose_optimized: for i in range(num_experts): # w13 weight shuffling - permute_indices = _maybe_get_cached_w2_permute_indices( + permute_indices = get_w2_permute_indices_with_cache( _cache_permute_indices, w13_weight[i].view(torch.uint8), epilogue_tile_m, @@ -331,7 +330,7 @@ def tg_mxfp4_moe( .contiguous() ) # w13 scale shuffling - permute_sf_indices = _maybe_get_cached_w2_permute_indices( + permute_sf_indices = get_w2_permute_indices_with_cache( _cache_permute_indices, w13_weight_scale[i].view(torch.uint8), epilogue_tile_m, @@ -345,7 +344,7 @@ def tg_mxfp4_moe( ) ) # w13 bias shuffling - permute_bias_indices = _maybe_get_cached_w2_permute_indices( + permute_bias_indices = get_w2_permute_indices_with_cache( _cache_permute_indices, w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m, @@ -357,7 
+356,7 @@ def tg_mxfp4_moe( .contiguous() ) # w2 weight shuffling - permute_indices = _maybe_get_cached_w2_permute_indices( + permute_indices = get_w2_permute_indices_with_cache( _cache_permute_indices, w2_weight[i].view(torch.uint8), epilogue_tile_m, @@ -368,7 +367,7 @@ def tg_mxfp4_moe( .contiguous() ) # w2 scale shuffling - permute_sf_indices = _maybe_get_cached_w2_permute_indices( + permute_sf_indices = get_w2_permute_indices_with_cache( _cache_permute_indices, w2_weight_scale[i].view(torch.uint8), epilogue_tile_m, @@ -382,7 +381,7 @@ def tg_mxfp4_moe( ) ) # w2 bias shuffling - permute_indices = _maybe_get_cached_w2_permute_indices( + permute_indices = get_w2_permute_indices_with_cache( _cache_permute_indices, w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m, @@ -510,7 +509,7 @@ def test_trtllm_gen_mxfp4_fused_moe( hidden_size: int, alpha: float, beta: float, - limit: Optional[float], + limit: float | None, act_type: str, transpose_optimized: bool, ): @@ -660,7 +659,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( hidden_size: int, alpha: float, beta: float, - limit: Optional[float], + limit: float | None, ): torch.manual_seed(42) device = "cuda:0" @@ -811,9 +810,9 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe( num_tokens: int, intermediate_size: int, hidden_size: int, - alpha: Optional[float], - beta: Optional[float], - limit: Optional[float], + alpha: float | None, + beta: float | None, + limit: float | None, ): torch.manual_seed(42) device = "cuda:0" diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 4c7c6c6a4f52..ac7f3fc5e6f0 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -73,7 +72,7 @@ def pplx_cutlass_moe( out_dtype, per_act_token: bool, per_out_ch: bool, - group_name: Optional[str], + group_name: str | None, ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 223f095c0b55..e665c636fa26 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -9,7 +9,7 @@ import itertools import textwrap import traceback -from typing import Callable, Optional, Union +from collections.abc import Callable import pytest import torch @@ -89,7 +89,7 @@ def torch_prepare( a: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - max_num_tokens: Optional[int] = None, + max_num_tokens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] @@ -214,10 +214,10 @@ def create_pplx_prepare_finalize( dp_size: int, world_size: int, in_dtype: torch.dtype, - quant_dtype: Optional[torch.dtype], - block_shape: Optional[list[int]], + quant_dtype: torch.dtype | None, + block_shape: list[int] | None, per_act_token_quant: bool, - group_name: Optional[str], + group_name: str | None, ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize, @@ -274,18 +274,14 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: return t[(r * chunk) : (r + 1) * chunk] -def maybe_chunk_by_rank( - t: Optional[torch.Tensor], r: int, w: int -) -> Optional[torch.Tensor]: +def maybe_chunk_by_rank(t: torch.Tensor | None, r: int, w: int) -> 
torch.Tensor | None: if t is not None: return chunk_by_rank(t, r, w) else: return t -def chunk_scales_by_rank( - t: Optional[torch.Tensor], r: int, w: int -) -> Optional[torch.Tensor]: +def chunk_scales_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None: if t is not None and t.numel() > 1: chunk = rank_chunk(t.shape[0], r, w) return t[(r * chunk) : (r + 1) * chunk] @@ -293,9 +289,7 @@ def chunk_scales_by_rank( return t -def chunk_scales( - t: Optional[torch.Tensor], start: int, end: int -) -> Optional[torch.Tensor]: +def chunk_scales(t: torch.Tensor | None, start: int, end: int) -> torch.Tensor | None: if t is not None and t.numel() > 1: return t[start:end] else: @@ -313,10 +307,10 @@ def pplx_prepare_finalize( topk_weight: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - quant_dtype: Optional[torch.dtype], - block_shape: Optional[list[int]], + quant_dtype: torch.dtype | None, + block_shape: list[int] | None, per_act_token_quant: bool, - group_name: Optional[str], + group_name: str | None, ) -> torch.Tensor: assert torch.cuda.current_device() == pgi.local_rank @@ -409,8 +403,8 @@ def _pplx_prepare_finalize( score: torch.Tensor, topk: torch.Tensor, num_experts: int, - quant_dtype: Optional[torch.dtype], - block_shape: Optional[list[int]], + quant_dtype: torch.dtype | None, + block_shape: list[int] | None, per_act_token_quant: bool, use_internode: bool, ): @@ -479,7 +473,7 @@ def test_pplx_prepare_finalize_slow( dtype: torch.dtype, world_dp_size: tuple[int, int], per_act_token_quant: bool, - block_shape: Optional[list[int]], + block_shape: list[int] | None, use_internode: bool, ): if dtype == torch.float8_e4m3fn: @@ -521,7 +515,7 @@ def test_pplx_prepare_finalize_slow( def pplx_moe( - group_name: Optional[str], + group_name: str | None, rank: int, world_size: int, dp_size: int, @@ -530,17 +524,17 @@ def pplx_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - quant_dtype: Optional[torch.dtype] = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + quant_dtype: torch.dtype | None = None, per_act_token_quant=False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, use_compile: bool = False, use_cudagraphs: bool = True, - shared_experts: Optional[torch.nn.Module] = None, -) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + shared_experts: torch.nn.Module | None = None, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] topk = topk_ids.shape[1] @@ -657,13 +651,13 @@ def _pplx_moe( score: torch.Tensor, topk: int, num_experts: int, - w1_s: Optional[torch.Tensor] = None, - w2_s: Optional[torch.Tensor] = None, - quant_dtype: Optional[torch.dtype] = None, + w1_s: torch.Tensor | None = None, + w2_s: torch.Tensor | None = None, + quant_dtype: torch.dtype | None = None, per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, use_internode: bool = False, - shared_experts: Optional[torch.nn.Module] = None, + shared_experts: torch.nn.Module | None = None, ): try: if use_internode: @@ -812,7 +806,7 @@ def test_pplx_moe_slow( dtype: torch.dtype, world_dp_size: tuple[int, int], per_act_token_quant: bool, - 
block_shape: Optional[list[int]], + block_shape: list[int] | None, use_internode: bool, ): current_platform.seed_everything(7) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index b6ca80e97e91..8b3bebb391f2 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,7 +5,7 @@ import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm_cuda, + persistent_masked_m_silu_mul_quant, ) from vllm.platforms import current_platform from vllm.utils import cdiv @@ -50,15 +50,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): # Input tensor of shape (E, T, 2*H) y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( - low=T // 2, + low=0, high=T, size=(E,), dtype=torch.int32, device="cuda", ) - # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda( + # Run the SiLU V2 kernel + y_q, y_s = persistent_masked_m_silu_mul_quant( y, tokens_per_expert, group_size=group_size ) @@ -115,10 +115,11 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): y_se = y_s[e].float() y_qe = y_q[e].float() - torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( y_qe[:nt].to(torch.float32), ref_q[:nt].to(torch.float32), atol=2, rtol=2e-1, ) + + torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 9466dacb0c11..65ce4073ad5b 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch @@ -27,13 +26,13 @@ def triton_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - quant_dtype: Optional[torch.dtype] = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + quant_dtype: torch.dtype | None = None, per_act_token_quant=False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> torch.Tensor: quant_config = FusedMoEQuantConfig.make( quant_dtype, @@ -54,13 +53,13 @@ def batched_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - quant_dtype: Optional[torch.dtype] = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + quant_dtype: torch.dtype | None = None, per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) @@ -94,13 +93,13 @@ def naive_batched_moe( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: 
Optional[torch.Tensor] = None, - quant_dtype: Optional[torch.dtype] = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + quant_dtype: torch.dtype | None = None, per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) @@ -129,8 +128,8 @@ def naive_batched_moe( def chunk_scales( - scales: Optional[torch.Tensor], start: int, end: int -) -> Optional[torch.Tensor]: + scales: torch.Tensor | None, start: int, end: int +) -> torch.Tensor | None: if scales is not None: if scales.numel() == 1: return scales @@ -144,10 +143,10 @@ def make_quantized_test_activations( m: int, k: int, in_dtype: torch.dtype, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None, + quant_dtype: torch.dtype | None = None, + block_shape: list[int] | None = None, per_act_token_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10 a_q = a a_scale = None @@ -172,11 +171,11 @@ def make_quantized_test_activations( def moe_quantize_weights( w: torch.Tensor, - w_s: Optional[torch.Tensor], - quant_dtype: Union[torch.dtype, str, None], + w_s: torch.Tensor | None, + quant_dtype: torch.dtype | str | None, per_token_quant: bool, - block_shape: Optional[list[int]], -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + block_shape: list[int] | None, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: assert ( quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8 @@ -220,10 +219,10 @@ def make_test_weight( rows: int, cols: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: Union[torch.dtype, str, None] = None, - block_shape: Optional[list[int]] = None, + quant_dtype: torch.dtype | str | None = None, + block_shape: list[int] | None = None, per_out_ch_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 w_gs = None @@ -262,12 +261,12 @@ def make_test_weights( n: int, k: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: Union[torch.dtype, str, None] = None, - block_shape: Optional[list[int]] = None, + quant_dtype: torch.dtype | str | None = None, + block_shape: list[int] | None = None, per_out_ch_quant: bool = False, ) -> tuple[ - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], + tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None], + tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None], ]: return ( make_test_weight( @@ -295,9 +294,9 @@ def make_test_quant_config( n: int, k: int, in_dtype: torch.dtype, - quant_dtype: Union[torch.dtype, str, None] = None, + quant_dtype: torch.dtype | str | None = None, per_act_token_quant: bool = False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]: (_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights( e, @@ -310,8 +309,8 @@ def 
make_test_quant_config( ) # Hacky/trivial scales for nvfp4. - a1_gscale: Optional[torch.Tensor] = None - a2_gscale: Optional[torch.Tensor] = None + a1_gscale: torch.Tensor | None = None + a2_gscale: torch.Tensor | None = None if quant_dtype == "nvfp4": a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32) a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32) @@ -348,9 +347,9 @@ def fused_moe( score: torch.Tensor, topk: int, renormalize: bool = False, - quant_config: Optional[FusedMoEQuantConfig] = None, + quant_config: FusedMoEQuantConfig | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk( hidden_states, score.float(), topk, renormalize @@ -378,7 +377,7 @@ def __init__( self.b = b.to(dtype=torch.float32) self.out_dtype = out_dtype - def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None @@ -422,8 +421,8 @@ def __init__( quant_config=None, reduce_results: bool = True, prefix: str = "", - w1_s: Optional[torch.Tensor] = None, - w2_s: Optional[torch.Tensor] = None, + w1_s: torch.Tensor | None = None, + w2_s: torch.Tensor | None = None, ) -> None: from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -481,7 +480,7 @@ def make_shared_experts( N: int, K: int, in_dtype: torch.dtype = torch.bfloat16, - quant_dtype: Union[torch.dtype, str, None] = None, + quant_dtype: torch.dtype | str | None = None, ) -> torch.nn.Module: from vllm.model_executor.layers.quantization.fp8 import Fp8Config diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index d892f2a5acc0..34ce91585520 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch @@ -15,13 +14,13 @@ FP8_DTYPE = current_platform.fp8_dtype() -def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: +def as_float32_tensor(x: float | torch.Tensor) -> torch.Tensor: return torch.as_tensor(x, dtype=torch.float32, device="cuda") def ref_dynamic_per_token_quant( - x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None -) -> tuple[torch.tensor, torch.tensor]: + x: torch.Tensor, quant_dtype: torch.dtype, scale_ub: torch.Tensor | None = None +) -> tuple[torch.Tensor, torch.Tensor]: assert quant_dtype in [torch.int8, FP8_DTYPE] if scale_ub is not None: assert quant_dtype == FP8_DTYPE @@ -76,8 +75,8 @@ def ref_dynamic_per_token_quant( # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # kernel def ref_dynamic_per_tensor_fp8_quant( - x: torch.tensor, -) -> tuple[torch.tensor, torch.tensor]: + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: fp8_traits = torch.finfo(FP8_DTYPE) fp8_traits_max = ( ROCM_FP8FNUZ_MAX @@ -104,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant( .clamp(fp8_traits_min, fp8_traits_max) .to(FP8_DTYPE) ) - return ref_out, ref_scale.view((1,)) + return ref_out, ref_scale.view((1, 1)) def native_w8a8_block_matmul( @@ -250,10 +249,10 @@ def per_block_cast_to_int8( def dequant( t: torch.Tensor, - scale: Optional[torch.Tensor], - block_shape: Optional[list[int]], + scale: torch.Tensor | None, + block_shape: 
list[int] | None, per_act_token_quant: bool, - out_dtype: Optional[torch.dtype] = torch.float32, + out_dtype: torch.dtype | None = torch.float32, ) -> torch.Tensor: if scale is not None: f32 = torch.float32 @@ -267,10 +266,10 @@ def dequant( def batched_dequant( t: torch.Tensor, - scale: Optional[torch.Tensor], - block_shape: Optional[list[int]], + scale: torch.Tensor | None, + block_shape: list[int] | None, per_act_token_quant: bool, - out_dtype: Optional[torch.dtype] = torch.float32, + out_dtype: torch.dtype | None = torch.float32, ) -> torch.Tensor: if scale is not None: assert t.shape[0] == scale.shape[0] @@ -289,9 +288,9 @@ def native_batched_masked_quant_matmul( B: torch.Tensor, C: torch.Tensor, num_expert_tokens: torch.Tensor, - A_scale: Optional[torch.Tensor] = None, - B_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, + A_scale: torch.Tensor | None = None, + B_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, per_act_token_quant: bool = False, ) -> torch.Tensor: num_expert_tokens_cpu = num_expert_tokens.clone() diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index 50be6841560b..5e6d54c42e89 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype): return values.reshape(m, n * 2).to(dtype=dtype) +def get_nvfp4_global_scale(a: torch.Tensor): + return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32) + + def quant_nvfp4_tensor(a: torch.Tensor): - a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to( - torch.float32 - ) + a_global_scale = get_nvfp4_global_scale(a) a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale) return a_quant, a_block_scale, a_global_scale diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index a6dfb5428c52..55f092e7ea69 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -18,12 +18,12 @@ w8a8_triton_block_scaled_mm, ) from vllm.platforms import current_platform -from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import ( fp8_gemm_nt, get_col_major_tma_aligned_tensor, per_block_cast_to_fp8, ) +from vllm.utils.import_utils import has_deep_gemm if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py index a3d524fe90ed..465e24fd7eb9 100644 --- a/tests/kernels/quantization/test_cutlass_w4a8.py +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -6,7 +6,6 @@ """ from dataclasses import dataclass -from typing import Optional import pytest import torch @@ -60,10 +59,10 @@ class TypeConfig: act_type: torch.dtype weight_type: ScalarType - output_type: Optional[torch.dtype] - group_scale_type: Optional[torch.dtype] - channel_scale_type: Optional[torch.dtype] - token_scale_type: Optional[torch.dtype] + output_type: torch.dtype | None + group_scale_type: torch.dtype | None + channel_scale_type: torch.dtype | None + token_scale_type: torch.dtype | None @dataclass @@ -80,7 +79,7 @@ class Tensors: # (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, # Ch Scales Type, Tok Scales Type) TestTypeTuple = tuple[ - list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool + 
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool ] TEST_TYPES = [ *( @@ -116,8 +115,8 @@ def cutlass_quantize_and_pack( atype: torch.dtype, w: torch.Tensor, wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], + stype: torch.dtype | None, + group_size: int | None, zero_points: bool = False, ): assert wtype.is_integer(), "TODO: support floating point weights" @@ -143,7 +142,7 @@ def cutlass_quantize_and_pack( def create_test_tensors( - shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] + shape: tuple[int, int, int], types: TypeConfig, group_size: int | None ) -> Tensors: m, n, k = shape @@ -185,8 +184,8 @@ def create_test_tensors( def mm_test_helper( types: TypeConfig, tensors: Tensors, - group_size: Optional[int] = None, - schedule: Optional[str] = None, + group_size: int | None = None, + schedule: str | None = None, ): # CUTLASS upstream uses fp8 with fastaccum as reference # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406 diff --git a/tests/kernels/quantization/test_gptq.py b/tests/kernels/quantization/test_gptq.py index 72e4194c1327..7bc7f97ce75b 100644 --- a/tests/kernels/quantization/test_gptq.py +++ b/tests/kernels/quantization/test_gptq.py @@ -26,4 +26,10 @@ def test_gptq_gemm_opcheck(): idx = torch.empty((0,), device="cuda", dtype=torch.int32) use_exllama = True bit = 4 - opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit)) + # Test both GPTQv1 and GPTQv2 format + opcheck( + torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, True, bit) + ) + opcheck( + torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, False, bit) + ) diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index b32523bb85d9..efa81de158d3 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -7,7 +7,6 @@ import math from dataclasses import dataclass, fields -from typing import Optional import pytest import torch @@ -50,11 +49,11 @@ class TypeConfig: act_type: torch.dtype weight_type: ScalarType - output_type: Optional[torch.dtype] - group_scale_type: Optional[torch.dtype] - group_zero_type: Optional[torch.dtype] - channel_scale_type: Optional[torch.dtype] - token_scale_type: Optional[torch.dtype] + output_type: torch.dtype | None + group_scale_type: torch.dtype | None + group_zero_type: torch.dtype | None + channel_scale_type: torch.dtype | None + token_scale_type: torch.dtype | None @dataclass @@ -63,10 +62,10 @@ class Tensors: a_ref: torch.Tensor a: torch.Tensor w_q: torch.Tensor - w_g_s: Optional[torch.Tensor] - w_g_zp: Optional[torch.Tensor] - w_ch_s: Optional[torch.Tensor] - w_tok_s: Optional[torch.Tensor] + w_g_s: torch.Tensor | None + w_g_zp: torch.Tensor | None + w_ch_s: torch.Tensor | None + w_tok_s: torch.Tensor | None # (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, @@ -74,7 +73,7 @@ class Tensors: # NOTE: None "Scale Type" means the act type is floating point # None "Output Type" means the output type is the same as the act type TestTypeTuple = tuple[ - list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool + list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool ] TEST_TYPES = [ # GPTQ style @@ -139,11 +138,11 @@ def rand_data(shape, dtype=torch.float16, scale=1, offset=0): return torch.randint(-8, 7, shape, dtype=dtype, 
device="cuda") -def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): +def maybe_convert_zeropoints(zps: torch.Tensor | None, s: torch.Tensor): return zps if zps is None else -1 * s * (zps.to(s.dtype)) -def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool: +def group_size_valid(shape: tuple[int, int, int], group_size: int | None) -> bool: return group_size is None or group_size == -1 or shape[2] % group_size == 0 @@ -151,8 +150,8 @@ def machete_quantize_and_pack( atype: torch.dtype, w: torch.Tensor, wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], + stype: torch.dtype | None, + group_size: int | None, zero_points: bool = False, ): assert wtype.is_integer(), "TODO: support floating point weights" @@ -178,8 +177,8 @@ def machete_quantize_and_pack( def create_test_tensors( shape: tuple[int, int, int], types: TypeConfig, - group_size: Optional[int], - subset_stride_factor: Optional[int] = None, + group_size: int | None, + subset_stride_factor: int | None = None, ) -> Tensors: m, n, k = shape factor = subset_stride_factor or 1 @@ -243,8 +242,8 @@ def create_test_tensors( def machete_mm_test_helper( types: TypeConfig, tensors: Tensors, - group_size: Optional[int] = None, - schedule: Optional[str] = None, + group_size: int | None = None, + schedule: str | None = None, ): output_ref = torch.matmul(tensors.a_ref, tensors.w_ref) output_ref_type = output_ref.dtype @@ -294,7 +293,7 @@ def machete_mm_test_helper( @pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_all_schedules(shape, types: TypeConfig): - group_sizes: list[Optional[int]] = [] + group_sizes: list[int | None] = [] if types.group_scale_type is None: group_sizes = [None] else: @@ -323,7 +322,7 @@ def test_machete_all_schedules(shape, types: TypeConfig): @pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_heuristic(shape, types: TypeConfig): - group_sizes: list[Optional[int]] = [] + group_sizes: list[int | None] = [] if types.group_scale_type is None: group_sizes = [None] else: diff --git a/tests/kernels/quantization/test_mxfp4_qutlass.py b/tests/kernels/quantization/test_mxfp4_qutlass.py new file mode 100644 index 000000000000..0bacbef2046b --- /dev/null +++ b/tests/kernels/quantization/test_mxfp4_qutlass.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
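A brief orientation for the new file: the `_rtne_fp4` / `_dq_fp4` helpers below work in the MXFP4 layout, where each byte packs two E2M1 codes and every group of 32 values shares one power-of-two E8M0 scale. The scalar sketch below shows that dequant relationship, with the scale passed as its exponent for simplicity; it is illustrative only, and `dequant_mxfp4_pair` / `E2M1_VALUES` are hypothetical names, not part of this change.

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
               -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequant_mxfp4_pair(packed_byte: int, e8m0_exponent: int) -> tuple[float, float]:
    # Low nibble is the first value, high nibble the second, matching the
    # unpacking order used by _dq_fp4 below.
    lo, hi = packed_byte & 0xF, (packed_byte >> 4) & 0xF
    scale = 2.0 ** e8m0_exponent  # E8M0 scales are plain powers of two
    return E2M1_VALUES[lo] * scale, E2M1_VALUES[hi] * scale

# Example: dequant_mxfp4_pair(0x31, 2) == (2.0, 6.0), i.e. codes 0x1 (0.5)
# and 0x3 (1.5), each scaled by 2**2.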
+# +import numpy as np +import pytest +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix + +from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.platforms import current_platform + +if not torch.cuda.is_available(): + pytest.skip("CUDA required for these tests.", allow_module_level=True) + +if not ( + current_platform.has_device_capability(100) + or current_platform.has_device_capability(120) +): + pytest.skip( + reason="Tests require compute capability 10.0 (100) or 12.0 (120).", + allow_module_level=True, + ) + + +# ----- Helpers ----- +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _rtne_fp4(x: torch.Tensor): + device = x.device + grid = torch.tensor( + [ + -6.0, + -4.0, + -3.0, + -2.0, + -1.5, + -1.0, + -0.5, + -0.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + ], + dtype=x.dtype, + device=x.device, + ) + grid_int = torch.tensor( + [-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7], + dtype=torch.uint8, + device=device, + ) + inds = torch.bucketize(x, grid) + lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15) + g_lo, g_hi = grid[lo], grid[hi] + pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0) + y = torch.where(pick_hi, g_hi, g_lo) + y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo]) + y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF + return y, y_int_packed + + +def _dq_fp4(x_e2m1: torch.Tensor, x_e8m0: torch.Tensor, alpha: float): + device = x_e2m1.device + + x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32) + x_e2m1_unpacked = torch.stack( + [x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1 + ).flatten(start_dim=-2) + + grid_dq = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float64, + device=device, + ) + x_fp4_dq = grid_dq[x_e2m1_unpacked] + scales_dq = x_e8m0.to(torch.float64) + + x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 32)) * scales_dq[..., None]).flatten( + start_dim=-2 + ) / alpha + return x_dq, x_fp4_dq, scales_dq + + +def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor: + clip_mask_unpacked_dq = torch.zeros( + *clip_mask.shape[:-1], + clip_mask.size(-1) * 8, + dtype=torch.bool, + device=clip_mask.device, + ) + for i in range(8): + clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1 + return clip_mask_unpacked_dq + + +def _forward_quantize_ref( + x: torch.Tensor, h: torch.Tensor, rot_size: int, quest: bool = True +): + device = x.device + xh_ref64 = ( + x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64) + @ h.reshape(rot_size, rot_size).to(dtype=torch.float64) + ).flatten(start_dim=-2) + + if quest: + scales_ref64_ = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).std(dim=-1, correction=0) + * (2.92247856 / 6.0) + + 1e-8 + ) + else: + abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).abs().amax(dim=-1) + scales_ref64_ = abs_max + 1e-8 + + xh_e8m0_ref = scales_ref64_.log2().floor().exp2().to(dtype=torch.float8_e8m0fnu) + scales_ref64 = xh_e8m0_ref.to(dtype=torch.float64) + + xh_scaled_ref64 = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 32)) / scales_ref64[..., None] + ).flatten(start_dim=-2) + if not quest: + xh_scaled_ref64 
*= 3 + + clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0 + clip_mask_ref = torch.zeros( + *x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device + ) + for i in range(8): + clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i + + xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64) + xh_dq, xh_fp4_dq, scales_dq = _dq_fp4( + xh_e2m1_ref, xh_e8m0_ref, alpha=1.0 if quest else 3.0 + ) + clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref) + + assert xh_fp4_dq.equal(xh_fp4_ref) + assert scales_dq.equal(scales_ref64) + assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref) + + return ( + xh_dq, + clip_mask_unpacked_ref, + (xh_e2m1_ref, xh_e8m0_ref, clip_mask_ref), + ) + + +DTYPE = torch.bfloat16 +DEVICE = torch.device("cuda:0") + +ROT_SIZES = [32, 64, 128] +SEEDS = [0] +BATCHES = [1, 16] + +LLAMA_MODELS = { + "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], + "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], + "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], + "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +} + + +@pytest.fixture(autouse=True) +def _seed_each_test(): + current_platform.seed_everything(0) + np.random.seed(0) + torch.random.manual_seed(0) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_fused_quantization_absmax(rot_size: int): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=False) + xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max") + xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4 + + m, n, k = 1, 504, 4096 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max") + a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_fused_quantization_quest(rot_size: int): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=True) + xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest") + xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4 + + m, n, k = 504, 504, 2048 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest") + 
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys())) +@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3]) +@pytest.mark.parametrize("batch", [1, 16]) +@pytest.mark.parametrize("had_size", ROT_SIZES) +@torch.inference_mode() +def test_llama_shapes(model: str, layer_idx: int, batch: int, had_size: int): + dtype, device = DTYPE, DEVICE + m = batch + k, n = LLAMA_MODELS[model][layer_idx] + + h = get_hadamard_matrix(had_size, dtype, device) + + a = torch.rand(m, k, dtype=dtype, device=device) * 25.0 + b = torch.rand(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest") + b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest") + + a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e8m0, backend="triton") + b_scale_block = to_blocked(b_e8m0, backend="triton") + alpha = torch.tensor([1.0], device=device) + out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha) + assert out.equal(out_ref.to(dtype=out.dtype)) diff --git a/tests/kernels/quantization/test_nvfp4_qutlass.py b/tests/kernels/quantization/test_nvfp4_qutlass.py new file mode 100644 index 000000000000..3824a080f504 --- /dev/null +++ b/tests/kernels/quantization/test_nvfp4_qutlass.py @@ -0,0 +1,268 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
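For orientation in this file: NVFP4 differs from the MXFP4 path in the sibling test mainly in its group size (16 instead of 32), its block-scale format (FP8 E4M3 rather than power-of-two E8M0), and the extra global scale. The reference quantizer below derives each group's scale from its absolute maximum and maps values into the ±6.0 E2M1 range before rounding. A minimal sketch of that scale selection follows; it is illustrative only, and `nvfp4_group_scale` / `nvfp4_prescale` are hypothetical names, not part of this change.

def nvfp4_group_scale(group: list[float]) -> float:
    # One scale per group of 16 rotated values, taken from the group's absolute
    # maximum (the reference below adds the same 1e-8 epsilon and then casts the
    # result to float8_e4m3fn).
    return max(abs(v) for v in group) + 1e-8

def nvfp4_prescale(value: float, scale: float) -> float:
    # Scale so the group maximum lands near the top E2M1 value (6.0) before
    # round-to-nearest-even onto the E2M1 grid.
    return value / scale * 6.0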
+# +import numpy as np +import pytest +import torch +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix + +from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm +from vllm._custom_ops import fusedQuantizeNv +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.platforms import current_platform + +if not torch.cuda.is_available(): + pytest.skip("CUDA required for these tests.", allow_module_level=True) + +if not ( + current_platform.has_device_capability(100) + or current_platform.has_device_capability(120) +): + pytest.skip( + reason="Tests require compute capability 10.0 (100) or 12.0 (120).", + allow_module_level=True, + ) + + +# ----- Helpers ----- +def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device): + return ( + deterministic_hadamard_matrix(group_size, dtype=dtype, device=device) + * group_size**-0.5 + ) + + +def _rtne_fp4(x: torch.Tensor): + device = x.device + grid = torch.tensor( + [ + -6.0, + -4.0, + -3.0, + -2.0, + -1.5, + -1.0, + -0.5, + -0.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + ], + dtype=x.dtype, + device=x.device, + ) + grid_int = torch.tensor( + [-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7], + dtype=torch.uint8, + device=device, + ) + inds = torch.bucketize(x, grid) + lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15) + g_lo, g_hi = grid[lo], grid[hi] + pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0) + y = torch.where(pick_hi, g_hi, g_lo) + y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo]) + y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF + return y, y_int_packed + + +def _dq_fp4(x_e2m1: torch.Tensor, x_e4m3: torch.Tensor, alpha: float): + device = x_e2m1.device + + x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32) + x_e2m1_unpacked = torch.stack( + [x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1 + ).flatten(start_dim=-2) + + grid_dq = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float64, + device=device, + ) + x_fp4_dq = grid_dq[x_e2m1_unpacked] + + scales_dq = x_e4m3.to(torch.float64) + x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 16)) * scales_dq[..., None]).flatten( + start_dim=-2 + ) / alpha # * (4. / 3.) 
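+ # Note: each packed byte holds two E2M1 codes and every group of 16 dequantized
+ # values shares one E4M3 scale; the final result is divided by `alpha`
+ # (callers pass either the global quantization scale or 1.0).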
+ return x_dq, x_fp4_dq, scales_dq + + +def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor: + clip_mask_unpacked_dq = torch.zeros( + *clip_mask.shape[:-1], + clip_mask.size(-1) * 8, + dtype=torch.bool, + device=clip_mask.device, + ) + for i in range(8): + clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1 + return clip_mask_unpacked_dq + + +def _forward_quantize_ref(x: torch.Tensor, h: torch.Tensor, rot_size: int): + device = x.device + + xh_ref64 = ( + x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64) + @ h.reshape(rot_size, rot_size).to(dtype=torch.float64) + ).flatten(start_dim=-2) + + abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 16)).abs().amax(dim=-1) + scales_ref64_ = abs_max + 1e-8 + + xh_e4m3_ref = scales_ref64_.to(dtype=torch.float8_e4m3fn) + scales_ref64 = xh_e4m3_ref.to(dtype=torch.float64) + xh_scaled_ref64 = ( + xh_ref64.unflatten(dim=-1, sizes=(-1, 16)) / scales_ref64[..., None] + ).flatten(start_dim=-2) + + xh_scaled_ref64 *= 6.0 + + clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0 + clip_mask_ref = torch.zeros( + *x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device + ) + for i in range(8): + clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i + + xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64) + xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(xh_e2m1_ref, xh_e4m3_ref, 6.0) + clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref) + + assert xh_fp4_dq.equal(xh_fp4_ref) + assert scales_dq.equal(scales_ref64) + assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref) + + return ( + xh_dq, + clip_mask_unpacked_ref, + (xh_e2m1_ref, xh_e4m3_ref, clip_mask_ref), + ) + + +DTYPE = torch.bfloat16 +DEVICE = torch.device("cuda:0") +ROT_SIZES = [16, 32, 64, 128] +GLOBAL_SCALES = [6.0] + +LLAMA_MODELS = { + "7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)], + "13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)], + "33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)], + "70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)], +} + + +@pytest.fixture(autouse=True) +def _seed_each_test(): + current_platform.seed_everything(0) + np.random.seed(0) + torch.random.manual_seed(0) + + +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@pytest.mark.parametrize("global_scale_value", GLOBAL_SCALES) +@torch.inference_mode() +def test_fused_quantization(rot_size: int, global_scale_value: float): + dtype, device = DTYPE, DEVICE + h = get_hadamard_matrix(rot_size, dtype, device) + x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0 + global_scale = torch.tensor([global_scale_value], device=device) + + xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size) + xh_e2m1, xh_e4m3 = fusedQuantizeNv(x, h, global_scale) + xh_e4m3 = xh_e4m3.reshape(2, 4096, 4096 // 16) + xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e4m3, alpha=global_scale_value) + + torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100) + assert (xh_dq != xh_dq_ref).float().mean() <= 1e-1 + + m, n, k = 504, 4096 * 2, 4096 + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale) + b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale) + + a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16) + 
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16) + alpha = torch.tensor([1.0], device=device) + out = ops.cutlass_scaled_fp4_mm( + a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16 + ) + assert out.equal(out_ref.to(dtype=out.dtype)) + + +@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys())) +@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3]) +@pytest.mark.parametrize("batch", [1, 16]) +@pytest.mark.parametrize("rot_size", ROT_SIZES) +@torch.inference_mode() +def test_llama_shapes(model: str, layer_idx: int, batch: int, rot_size: int): + dtype, device = DTYPE, DEVICE + m = batch + k, n = LLAMA_MODELS[model][layer_idx] + + h = get_hadamard_matrix(rot_size, dtype, device) + + a = torch.randn(m, k, dtype=dtype, device=device) * 25.0 + b = torch.randn(n, k, dtype=dtype, device=device) * 25.0 + + global_scale = torch.tensor([1.0], device=device) + + a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale) + b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale) + + a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0) + b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0) + out_ref = a_dq @ b_dq.transpose(-2, -1) + + a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16) + b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16) + alpha = torch.tensor([1.0], device=device) + out = ops.cutlass_scaled_fp4_mm( + a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16 + ) + assert out.equal(out_ref.to(dtype=out.dtype)) diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index 1026332d99f8..6633a8bbd3c6 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -6,7 +6,6 @@ """ import importlib -from typing import Optional import pytest import torch @@ -27,7 +26,7 @@ def torch_scaled_mm( scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: out = torch.mm(a.to(torch.float32), b.to(torch.float32)) out = scale_a * out diff --git a/tests/kernels/test_fla_layernorm_guard.py b/tests/kernels/test_fla_layernorm_guard.py new file mode 100644 index 000000000000..f944c6dcfa73 --- /dev/null +++ b/tests/kernels/test_fla_layernorm_guard.py @@ -0,0 +1,388 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.fla.ops.layernorm_guard import ( + layer_norm_fwd, + layernorm_fn, + rms_norm_ref, +) +from vllm.platforms import current_platform + + +def layer_norm_ref( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + """Reference implementation for both layer norm and RMS norm.""" + if is_rms_norm: + # Use the imported rms_norm_ref for RMS norm cases + return rms_norm_ref( + x, + weight, + bias, + z=z, + eps=eps, + group_size=group_size, + norm_before_gate=norm_before_gate, + upcast=True, + ) + + # Layer norm implementation + dtype = x.dtype + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + z = z.float() if z is not None else None + + if z is not None and not norm_before_gate: + x = x * F.silu(z) + + if group_size is None: + # Layer norm: subtract mean + mean = x.mean(dim=-1, keepdim=True) + var = ((x - 
mean).square()).mean(dim=-1, keepdim=True) + rstd = 1 / torch.sqrt(var + eps) + out = (x - mean) * rstd * weight + if bias is not None: + out = out + bias + else: + # Group norm + from einops import rearrange + + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + mean = x_group.mean(dim=-1, keepdim=True) + var = ((x_group - mean).square()).mean(dim=-1, keepdim=True) + rstd = 1 / torch.sqrt(var + eps) + x_group = (x_group - mean) * rstd + out = rearrange(x_group, "... g d -> ... (g d)") * weight + if bias is not None: + out = out + bias + + if z is not None and norm_before_gate: + out *= F.silu(z) + + return out.to(dtype) + + +DTYPES = [torch.bfloat16, torch.float32] +# Test various M sizes to ensure rows_per_block logic works correctly +NUM_TOKENS = [ + 1, + 7, + 16, + 63, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 5789, + 8189, + 8191, + 16383, + 32767, +] +HIDDEN_SIZES = [64, 128, 256, 1024] +GROUP_SIZES = [None, 64, 128] # None means full hidden size +NORM_BEFORE_GATE = [True, False] +IS_RMS_NORM = [True, False] +SEEDS = [0, 42] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM) +@torch.inference_mode() +def test_layer_norm_fwd_basic( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + is_rms_norm: bool, +) -> None: + """Test basic layer norm forward pass without z (gate) tensor.""" + current_platform.seed_everything(seed) + device = torch.device("cuda:0") + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Run the triton kernel + out, mean, rstd = layer_norm_fwd( + x, weight, bias, eps, z=None, is_rms_norm=is_rms_norm + ) + + # Run reference implementation + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=is_rms_norm) + + # Check outputs + assert out.shape == x.shape + assert out.dtype == x.dtype + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + # Check mean and rstd shapes + if not is_rms_norm: + assert mean.shape == (num_tokens,) + assert rstd.shape == (num_tokens,) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", [128, 256, 1024]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("norm_before_gate", NORM_BEFORE_GATE) +@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM) +@torch.inference_mode() +def test_layer_norm_fwd_with_gate( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + norm_before_gate: bool, + is_rms_norm: bool, +) -> None: + """Test layer norm forward pass with z (gate) tensor.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + z = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Run the triton kernel + out, mean, rstd = layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + + # Run reference implementation + ref_out = layer_norm_ref( + x, + 
weight, + bias, + z=z, + eps=eps, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + + # Check outputs + assert out.shape == x.shape + assert out.dtype == x.dtype + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("num_tokens", [128, 512]) +@pytest.mark.parametrize("hidden_size", [512, 1024]) +@pytest.mark.parametrize("group_size", [64, 128, 256]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM) +@torch.inference_mode() +def test_layer_norm_fwd_with_groups( + num_tokens: int, + hidden_size: int, + group_size: int, + dtype: torch.dtype, + is_rms_norm: bool, +) -> None: + """Test layer norm forward pass with group normalization.""" + if hidden_size % group_size != 0: + pytest.skip( + f"hidden_size {hidden_size} not divisible by group_size {group_size}" + ) + + current_platform.seed_everything(42) + device = torch.device("cuda:0") + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + ngroups = hidden_size // group_size + + # Run the triton kernel + out, mean, rstd = layer_norm_fwd( + x, weight, bias, eps, z=None, group_size=group_size, is_rms_norm=is_rms_norm + ) + + # Run reference implementation + ref_out = layer_norm_ref( + x, weight, bias, z=None, eps=eps, group_size=group_size, is_rms_norm=is_rms_norm + ) + + # Check outputs + assert out.shape == x.shape + assert out.dtype == x.dtype + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + # Check mean and rstd shapes for groups + if not is_rms_norm: + assert mean.shape == (ngroups * num_tokens,) + assert rstd.shape == (ngroups * num_tokens,) + + +@pytest.mark.parametrize("num_tokens", [7, 63, 128, 513, 1024, 2049]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@torch.inference_mode() +def test_layer_norm_rows_per_block( + num_tokens: int, + dtype: torch.dtype, +) -> None: + """Test that rows_per_block logic works correctly for various M sizes.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + hidden_size = 1024 + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Run the triton kernel + out, mean, rstd = layer_norm_fwd(x, weight, bias, eps, z=None, is_rms_norm=False) + + # Run reference implementation + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) + + # Check outputs + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.inference_mode() +def test_strided_input(dtype: torch.dtype) -> None: + """Test that the kernel handles non-contiguous (strided) + inputs correctly.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + num_tokens = 128 + hidden_size = 1024 + + # Create a larger tensor and take a strided slice + x_large = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device) + x = x_large[:, :hidden_size] + + # Make it contiguous for the kernel + x_contiguous = x.contiguous() + + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + 
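+ # Note: the strided slice is materialized via .contiguous() above, so both the
+ # kernel and the reference operate on the same contiguous copy of the data.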
+ # Run the triton kernel with contiguous input + out, mean, rstd = layer_norm_fwd( + x_contiguous, weight, bias, eps, z=None, is_rms_norm=False + ) + + # Run reference implementation + ref_out = layer_norm_ref( + x_contiguous, weight, bias, z=None, eps=eps, is_rms_norm=False + ) + + # Check outputs + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("num_tokens", [1, 128, 2048]) +@pytest.mark.parametrize("hidden_size", [768, 4096]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@torch.inference_mode() +def test_output_buffer_provided( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, +) -> None: + """Test that the kernel works when an output buffer is provided.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Pre-allocate output buffer + out_buffer = torch.empty_like(x) + + # Run the triton kernel with provided output + out, mean, rstd = layer_norm_fwd( + x, weight, bias, eps, z=None, out=out_buffer, is_rms_norm=False + ) + + # Check that the provided buffer was used + assert out.data_ptr() == out_buffer.data_ptr() + + # Run reference implementation + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) + + # Check outputs + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize( + "shape", + [ + (4, 16, 1024), # 3D tensor + (2, 8, 512, 256), # 4D tensor + ], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@torch.inference_mode() +def test_multidimensional_input( + shape: tuple, + dtype: torch.dtype, +) -> None: + """Test that the autograd function handles multidimensional inputs.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + hidden_size = shape[-1] + + # Create inputs + x = torch.randn(*shape, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Run through autograd function + out = layernorm_fn(x, weight, bias, z=None, eps=eps) + + # Run reference implementation + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) + + # Check outputs + assert out.shape == x.shape + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + # Run a quick smoke test + test_layer_norm_fwd_basic(128, 1024, torch.float16, 42, False) + test_layer_norm_fwd_with_gate(128, 1024, torch.float16, True, False) + test_layer_norm_rows_per_block(513, torch.float16) + print("All smoke tests passed!") diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py index 9f78c177a81f..c9eca1f86d3a 100644 --- a/tests/kernels/test_onednn.py +++ b/tests/kernels/test_onednn.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Integration tests for FlexAttention backend vs default backend""" -from typing import Optional - import pytest import torch @@ -38,8 +36,8 @@ def ref_int8_scaled_mm( b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, - azp: Optional[torch.Tensor], - bias: Optional[torch.Tensor], + azp: torch.Tensor | None, + bias: torch.Tensor | None, output_type: torch.dtype, ): if azp is not None: diff --git 
a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py index f52cddc8c370..cadda27b49e9 100644 --- a/tests/kernels/test_top_k_per_row.py +++ b/tests/kernels/test_top_k_per_row.py @@ -10,6 +10,8 @@ # Test parameters NUM_ROWS = [1, 32, 2050] TOP_K_VALUES = [2048] +BATCH_SIZE = [1, 2, 4, 2048, 4096] +NEXT_N = [1, 2, 4, 8] def create_random_logits( @@ -39,10 +41,9 @@ def create_row_boundaries( def compare_top_k_results( + logits: torch.Tensor, cuda_indices: torch.Tensor, - cuda_values: torch.Tensor, torch_indices: torch.Tensor, - torch_values: torch.Tensor, row_starts: torch.Tensor, row_ends: torch.Tensor, top_k: int, @@ -70,8 +71,9 @@ def compare_top_k_results( continue # Any difference in elements, compare the values - cuda_row_values = cuda_values[row_idx][:num_valid].cpu() - torch_row_values = torch_values[row_idx][:num_valid].cpu() + logits_row = logits[row_idx] + cuda_row_values = [logits_row[i] for i in cuda_row_indices] + torch_row_values = [logits_row[i] for i in torch_row_indices] cuda_only_values, torch_only_values = [], [] for idx in cuda_set - torch_set: @@ -114,8 +116,7 @@ def test_top_k_per_row( logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42) # Create output tensors - indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda") - values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda") + indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") # Run CUDA implementation torch.ops._C.top_k_per_row( @@ -123,15 +124,13 @@ def test_top_k_per_row( row_starts, row_ends, indices, - values, num_rows, - top_k, logits.stride(0), logits.stride(1), ) # Run reference implementation - torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1) + torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1] mask_lo = torch_indices >= 0 mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 mask = mask_lo & mask_hi @@ -139,5 +138,61 @@ def test_top_k_per_row( # Compare results assert compare_top_k_results( - indices, values, torch_indices, torch_values, row_starts, row_ends, top_k + logits, indices, torch_indices, row_starts, row_ends, top_k + ), "CUDA top_k_per_row results don't match torch.topk" + + +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("next_n", NEXT_N) +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@torch.inference_mode() +def test_top_k_per_row_decode( + top_k: int, + batch_size: int, + next_n: int, +) -> None: + """ + Test top_k_per_row with seq_lens tensor. 
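+
+ Rows are laid out as batch_size * next_n positions; each row's valid length
+ (row_ends) is derived from its sequence length and its offset within next_n.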
+ """ + torch.set_default_device("cuda:0") + + # Create test data + num_rows = batch_size * next_n + vocab_size = 20000 + seq_lens = torch.randint( + vocab_size, (batch_size,), dtype=torch.int32, device="cuda" + ) + row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda") + row_indices = torch.arange(num_rows, device="cuda") // next_n + next_n_offset = torch.arange(num_rows, device="cuda") % next_n + row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1 + logits = create_random_logits(row_starts, row_ends, vocab_size, torch.float32, 42) + + # Create output tensors + indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") + + # Run CUDA implementation + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + indices, + num_rows, + logits.stride(0), + logits.stride(1), + ) + + torch.cuda.synchronize() + + # Run reference implementation + torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1] + mask_lo = torch_indices >= 0 + mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 + mask = mask_lo & mask_hi + torch_indices = torch_indices.masked_fill(~mask, -1) + + # Compare results + assert compare_top_k_results( + logits, indices, torch_indices, row_starts, row_ends, top_k ), "CUDA top_k_per_row results don't match torch.topk" diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 015424d9ee0f..eb00bc72b4b0 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -7,7 +7,7 @@ import unittest from collections.abc import Sequence from numbers import Number -from typing import Any, NamedTuple, Optional, Union +from typing import Any, NamedTuple import pytest import torch @@ -22,8 +22,8 @@ STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, - make_tensor_with_pad, ) +from vllm.utils.torch_utils import make_tensor_with_pad # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. 
@@ -96,10 +96,10 @@ class PackedQKVInputs(NamedTuple): query: torch.Tensor key: torch.Tensor value: torch.Tensor - q_start_loc_list: Optional[list[int]] - kv_start_loc_list: Optional[list[int]] - q_seq_lens: Optional[list[int]] - kv_seq_lens: Optional[list[int]] + q_start_loc_list: list[int] | None + kv_start_loc_list: list[int] | None + q_seq_lens: list[int] | None + kv_seq_lens: list[int] | None class PackedQKVO(NamedTuple): @@ -115,7 +115,7 @@ class PackedQKVO(NamedTuple): x head_size) known-correct attention output """ - packed_qkv: Optional[PackedQKVInputs] + packed_qkv: PackedQKVInputs | None ideal_output: torch.Tensor @@ -149,12 +149,12 @@ class PhaseTestParameters(NamedTuple): """ packed_qkvo: PackedQKVO - kv_mmap: Optional[KVMemoryMap] + kv_mmap: KVMemoryMap | None def maybe_make_int_tensor( - _list: Optional[list[int]], - device: Union[torch.device, str], + _list: list[int] | None, + device: torch.device | str, ) -> torch.Tensor: """ Convert Python int list to a 1D int torch.Tensor on `device` @@ -170,8 +170,8 @@ def maybe_make_int_tensor( def maybe_make_long_tensor( - _list: Optional[list[int]], - device: Union[torch.device, str], + _list: list[int] | None, + device: torch.device | str, ) -> torch.Tensor: """ Convert Python int list to a 1D long torch.Tensor on `device` @@ -186,7 +186,7 @@ def maybe_make_long_tensor( ) -def maybe_max(_list: Optional[list]) -> Optional[Number]: +def maybe_max(_list: list | None) -> Number | None: """ Returns: @@ -241,9 +241,9 @@ def ref_masked_attention( key: torch.Tensor, value: torch.Tensor, scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[list] = None, - kv_seq_lens: Optional[list] = None, + custom_mask: torch.Tensor | None = None, + q_seq_lens: list | None = None, + kv_seq_lens: list | None = None, ) -> torch.Tensor: """ "Golden" masked attention reference. 
Supports two types of masking: @@ -302,11 +302,11 @@ def ref_masked_attention( def make_qkv( batch_size: int, max_q_seq_len: int, - max_kv_seq_len: Optional[int], + max_kv_seq_len: int | None, num_heads: int, head_size: int, - device: Union[torch.device, str], - force_kv_seq_lens: Optional[list[int]] = None, + device: torch.device | str, + force_kv_seq_lens: list[int] | None = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, ) -> tuple[QKVInputs, QKVInputs, QKVInputs]: @@ -436,7 +436,7 @@ def make_qkv( def pack_tensor( - unpacked_tensor: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str] + unpacked_tensor: torch.Tensor, seq_lens: list[int], device: torch.device | str ) -> tuple[torch.Tensor, list[int]]: """ Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an @@ -470,7 +470,7 @@ def pack_tensor( return packed_tensor, start_loc_list -def pack_qkv(qkv: QKVInputs, device: Union[torch.device, str]) -> PackedQKVInputs: +def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs: """ Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x @@ -594,19 +594,19 @@ def make_alibi_bias( def _make_metadata_tensors( - seq_lens: Optional[list[int]], - context_lens: Optional[list[int]], - encoder_seq_lens: Optional[list[int]], - device: Union[torch.device, str], + seq_lens: list[int] | None, + context_lens: list[int] | None, + encoder_seq_lens: list[int] | None, + device: torch.device | str, ) -> tuple[ torch.Tensor, torch.Tensor, Any, Any, - Optional[torch.Tensor], + torch.Tensor | None, torch.Tensor, torch.Tensor, - Optional[int], + int | None, ]: """ Build scalar & tensor values required to build attention metadata structure. @@ -678,7 +678,7 @@ def make_kv_cache( num_heads: int, head_size: int, block_size: int, - device: Union[torch.device, str], + device: torch.device | str, backend: str, default_val: float = 0.0, ) -> torch.Tensor: @@ -726,18 +726,18 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: return (num_tokens + block_size) // block_size -def make_empty_slot_mapping_tensor(device: Union[torch.device, str]): +def make_empty_slot_mapping_tensor(device: torch.device | str): return maybe_make_long_tensor([], device) -def make_empty_block_tables_tensor(device: Union[torch.device, str]): +def make_empty_block_tables_tensor(device: torch.device | str): return torch.tensor([], device=device) def split_slot_mapping( slot_mapping_list: torch.Tensor, seq_lens: list[int], - device: Union[torch.device, str], + device: torch.device | str, ): """ Split a slot mapping into valid prefill- and decode-phase slot mappings. 
@@ -799,7 +799,7 @@ def split_slot_mapping( def make_block_tables_slot_mapping( block_size: int, seq_lens: list[int], - device: Union[torch.device, str], + device: torch.device | str, block_base_addr: int = 0, ) -> tuple[torch.Tensor, list[int], int]: """ @@ -880,11 +880,11 @@ def make_block_tables_slot_mapping( def make_test_metadata( attn_backend: _Backend, is_prompt: bool, - seq_lens: Optional[list[int]], - decoder_test_params: Optional[PhaseTestParameters], - device: Union[torch.device, str], - encoder_test_params: Optional[PhaseTestParameters] = None, - cross_test_params: Optional[PhaseTestParameters] = None, + seq_lens: list[int] | None, + decoder_test_params: PhaseTestParameters | None, + device: torch.device | str, + encoder_test_params: PhaseTestParameters | None = None, + cross_test_params: PhaseTestParameters | None = None, ) -> AttentionMetadata: """ Construct fake attention metadata for a given test phase @@ -1142,16 +1142,16 @@ def torch_experts( topk_weight: torch.Tensor, topk_ids: torch.Tensor, global_num_experts: int = -1, - b_bias1: Optional[torch.Tensor] = None, - b_bias2: Optional[torch.Tensor] = None, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - quant_dtype: Optional[torch.dtype] = None, + b_bias1: torch.Tensor | None = None, + b_bias2: torch.Tensor | None = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + quant_dtype: torch.dtype | None = None, per_act_token_quant=False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, apply_router_weights_on_input: bool = False, ) -> torch.Tensor: assert ( @@ -1261,10 +1261,10 @@ def torch_moe( w2: torch.Tensor, score: torch.Tensor, topk: int, - b_bias1: Optional[torch.Tensor] = None, - b_bias2: Optional[torch.Tensor] = None, + b_bias1: torch.Tensor | None = None, + b_bias2: torch.Tensor | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, ) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) @@ -1298,15 +1298,13 @@ def torch_moe_single(a, w, score, topk): # A special version of op check that has a restricted default set of test_utils # and a patched version of allclose that supports fp8 types. 
def opcheck( - op: Union[ - torch._ops.OpOverload, - torch._ops.OpOverloadPacket, - torch._library.custom_ops.CustomOpDef, - ], + op: torch._ops.OpOverload + | torch._ops.OpOverloadPacket + | torch._library.custom_ops.CustomOpDef, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, *, - test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, + test_utils: str | Sequence[str] = ALL_OPCHECK_TEST_UTILS, raise_exception: bool = True, cond: bool = True, ) -> dict[str, str]: @@ -1338,7 +1336,7 @@ def baseline_scaled_mm( scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: # We treat N-dimensional group scaling as extended numpy-style broadcasting # in numpy simply stretches dimensions with an extent of 1 to match diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index f805a74a4dba..2a688216f25e 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -230,6 +230,26 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") +@pytest.fixture(scope="session") +def deepseekv2_lora_files(): + return snapshot_download(repo_id="wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA") + + +@pytest.fixture(scope="session") +def gptoss20b_lora_files(): + return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter") + + +@pytest.fixture(scope="session") +def qwen3moe_lora_files(): + return snapshot_download(repo_id="jeeejeee/qwen3-moe-text2sql-spider") + + +@pytest.fixture(scope="session") +def olmoe_lora_files(): + return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider") + + @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 2f28253bce53..9a82ab99ea9c 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -12,7 +12,7 @@ from vllm.inputs import TextPrompt from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams -from vllm.utils import merge_async_iterators +from vllm.utils.async_utils import merge_async_iterators MODEL_PATH = "zai-org/chatglm3-6b" LORA_RANK = 64 diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index d8058c5f87a8..8f42243387d2 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import vllm +import vllm.config from vllm.lora.request import LoRARequest from ..utils import create_new_process_for_each_test, multi_gpu_test @@ -53,12 +54,12 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM( MODEL_PATH, - max_model_len=1024, + max_model_len=512, enable_lora=True, - max_loras=4, + max_loras=2, + max_num_seqs=16, max_lora_rank=64, trust_remote_code=True, - enable_chunked_prefill=True, ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) @@ -70,18 +71,20 @@ def test_chatglm3_lora(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) -@create_new_process_for_each_test() def test_chatglm3_lora_tp4(chatglm3_lora_files): llm = vllm.LLM( MODEL_PATH, - max_model_len=1024, + max_model_len=512, enable_lora=True, - max_loras=4, + max_loras=2, max_lora_rank=64, + max_num_seqs=16, tensor_parallel_size=4, trust_remote_code=True, fully_sharded_loras=False, - enable_chunked_prefill=True, + 
compilation_config=vllm.config.CompilationConfig( # Avoid OOM + cudagraph_specialize_lora=False, + ), ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) @@ -93,22 +96,23 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) -@create_new_process_for_each_test() def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): # https://github.com/NVIDIA/nccl/issues/1790, set a lower value for # gpu_memory_utilization here because NCCL >= 2.26.3 seems to use # more GPU memory causing vLLM to OOM llm = vllm.LLM( MODEL_PATH, - max_model_len=1024, + max_model_len=512, enable_lora=True, - max_loras=4, + max_loras=2, max_lora_rank=64, tensor_parallel_size=4, trust_remote_code=True, fully_sharded_loras=True, - enable_chunked_prefill=True, - gpu_memory_utilization=0.85, + gpu_memory_utilization=0.8, + compilation_config=vllm.config.CompilationConfig( # Avoid OOM + cudagraph_specialize_lora=False, + ), ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): diff --git a/tests/lora/test_deepseekv2_tp.py b/tests/lora/test_deepseekv2_tp.py new file mode 100644 index 000000000000..98b7e6333f30 --- /dev/null +++ b/tests/lora/test_deepseekv2_tp.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import vllm +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "deepseek-ai/DeepSeek-V2-Lite-Chat" + +PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501 + + +def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int): + prompts = [ + PROMPT_TEMPLATE.format(context="Who are you?"), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) + # Print the outputs. + generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + # return generated_texts + expected_lora_output = [ + "I am \u5f20\u5b50\u8c6a, an AI assistant developed by \u9648\u58eb\u680b.", # noqa: E501 + ] + for i in range(len(expected_lora_output)): + assert generated_texts[i].startswith(expected_lora_output[i]) + + +def test_deepseekv2_lora(deepseekv2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + generate_and_test(llm, deepseekv2_lora_files, 1) + + +def test_deepseekv2(deepseekv2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + ) + generate_and_test(llm, deepseekv2_lora_files, 1) + + +@multi_gpu_test(num_gpus=2) +def test_deepseekv2_tp2(deepseekv2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. 
+ llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + tensor_parallel_size=2, + ) + generate_and_test(llm, deepseekv2_lora_files, 2) + + +@multi_gpu_test(num_gpus=4) +def test_deepseekv2_tp4(deepseekv2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + tensor_parallel_size=4, + ) + generate_and_test(llm, deepseekv2_lora_files, 2) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py new file mode 100644 index 000000000000..f9a66d4d02ea --- /dev/null +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.lora.ops.triton_ops import fused_moe_lora +from vllm.platforms import current_platform + + +@pytest.fixture(autouse=True) +def reset_device(reset_default_device): + pass + + +def round_up(x, base): + return ((x + base - 1) // base) * base + + +def CEILDIV(x, y): + return (x + y - 1) // y + + +def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): + """ + Split `num_tokens` into `num_sequences` sequences. + Each sequence randomly selects 1 LoRA index from [0, max_loras), + and all tokens in that sequence are assigned this LoRA index. + + Args: + num_tokens (int): Total number of tokens. + num_sequences (int): Number of sequences to split the tokens into. + max_loras (int): Total number of available LoRA modules. + + Returns: + torch.Tensor: 1D tensor of shape [num_tokens], where each value + is the LoRA index assigned to that token. + """ + assert num_sequences > 0 and max_loras > 0 + assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences" + + # Compute token distribution per sequence (distribute remainder evenly) + tokens_per_seq = num_tokens // num_sequences + remainder = num_tokens % num_sequences + + token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32) + + start = 0 + for seq_idx in range(num_sequences): + # Determine the token range for this sequence + end = start + tokens_per_seq + (1 if seq_idx < remainder else 0) + + # Randomly select one LoRA ID for this sequence + lora_id = random.randint(0, max_loras - 1) + + # Assign the same LoRA ID to all tokens in this sequence + token_lora_mapping[start:end] = lora_id + + start = end + + return token_lora_mapping + + +def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int): + """ + For each token, randomly select `top_k_num` distinct experts out of `num_experts`, + and assign normalized random weights that sum to 1. + + Args: + num_tokens (int): Total number of tokens. + num_experts (int): Total number of available experts. + top_k_num (int): Number of experts to select per token. + + Returns: + expert_indices (torch.Tensor): shape [num_tokens, top_k_num], + expert index for each token. + expert_weights (torch.Tensor): shape [num_tokens, top_k_num], + normalized weights (sum = 1 per row). 
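+
+ Example (illustrative only, the selection is random): with num_tokens=2,
+ num_experts=4, top_k_num=2, one possible result is
+ expert_indices=[[3, 1], [0, 2]], with each row of expert_weights summing to 1.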
+ """ + assert top_k_num <= num_experts, "top_k_num must be <= num_experts" + + # Randomly select top_k_num distinct experts for each token + expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) + for i in range(num_tokens): + # Randomly choose unique expert indices + selected = torch.randperm(num_experts)[:top_k_num] + expert_indices[i] = selected + + # Generate random weights and normalize along dim=1 + expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32) + expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True) + + return expert_indices, expert_weights + + +def sample_data( + num_tokens: int, + num_sequences: int, + max_loras: int, + num_experts: int, + top_k_num: int, +): + topk_ids, topk_weights = assign_experts_to_tokens( + num_tokens, num_experts, top_k_num + ) + token_lora_mapping = assign_loras_to_tokens(num_tokens, num_sequences, max_loras) + return topk_ids, topk_weights, token_lora_mapping + + +def use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, +): + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) + + # init output tensors + sorted_token_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + ) + expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32) + num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) + + # call kernel + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + ) + + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + + mul_routed_weight = False + expert_ids = expert_ids.view(max_loras, -1) + sorted_token_ids = sorted_token_ids.view(max_loras, -1) + + fused_moe_lora( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_lora_rank, + top_k_num, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + mul_routed_weight, + ) + + return output + + +def use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, +): + outputs = [] + for i in range(hidden_states.shape[0]): + lora_idx = token_lora_mapping[i] + expert_ids = topk_ids[i] + lora_a = lora_a_stacked[0][lora_idx][expert_ids] + lora_b = lora_b_stacked[0][lora_idx][expert_ids] + tensors = [ + hidden_states[i] @ lora_a[x].T @ lora_b[x].T for x in range(top_k_num) + ] + outputs.append(torch.stack(tensors, dim=0)) + return torch.stack(outputs, dim=0) + + +@pytest.mark.parametrize("num_tokens", [100]) +@pytest.mark.parametrize("top_k_num", [6, 12]) +@pytest.mark.parametrize("num_experts", [64]) +@pytest.mark.parametrize("max_loras", [4, 6, 16]) +@pytest.mark.parametrize("N", [1408]) +@pytest.mark.parametrize("K", [2048]) +@pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) +@pytest.mark.parametrize("block_size", [16]) +def test_fused_moe_lora_kernel( + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, +): + 
torch.set_default_device("cuda:0") + current_platform.seed_everything(42) + # the number of randomly generated sentences. + num_sequences = 10 + # generate data + topk_ids, topk_weights, token_lora_mapping = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + # init lora weights + lora_a_stacked = [ + torch.rand( + ( + max_loras, + num_experts, + max_lora_rank, + K, + ), + dtype=torch.bfloat16, + ) + ] + lora_b_stacked = [ + torch.rand( + ( + max_loras, + num_experts, + N, + max_lora_rank, + ), + dtype=torch.bfloat16, + ) + ] + hidden_states = torch.rand( + ( + num_tokens, + K, + ), + dtype=torch.bfloat16, + ) + + # fused_moe_lora_kernel output + output = torch.zeros((num_tokens, top_k_num, N), dtype=torch.bfloat16) + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + ) + # pytorch output + output2 = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, + ) + + torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1) diff --git a/tests/lora/test_gptoss.py b/tests/lora/test_gptoss.py new file mode 100644 index 000000000000..f5c9a5cf20e0 --- /dev/null +++ b/tests/lora/test_gptoss.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "openai/gpt-oss-20b" + +PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501 + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: + prompts = [ + PROMPT_TEMPLATE.format(context="Who are you?"), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) + # Print the outputs. + generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +# FIXME: Load gpt-oss adapter +def test_gptoss20b_lora(gptoss20b_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_loras=4, + trust_remote_code=True, + ) + + expected_lora_output = [ + "I am an AI language model developed by OpenAI. " + "I am here to help you with any questions or " + "tasks you may have." 
+ ] + + output1 = do_sample(llm, gptoss20b_lora_files, lora_id=1) + print(output1) + for i in range(len(expected_lora_output)): + assert output1[i].startswith(expected_lora_output[i]) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 695e06e7c1d6..8f18f0144193 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -4,7 +4,6 @@ import random from copy import deepcopy from dataclasses import dataclass -from typing import Optional from unittest.mock import patch import pytest @@ -106,7 +105,7 @@ def skip_cuda_with_stage_false(request): def get_random_id_to_index( num_loras: int, num_slots: int, log: bool = True -) -> list[Optional[int]]: +) -> list[int | None]: """Creates a random lora_id_to_index mapping. Args: @@ -122,7 +121,7 @@ def get_random_id_to_index( "num_loras must be less than or equal to num_slots." ) - slots: list[Optional[int]] = [None] * num_slots + slots: list[int | None] = [None] * num_slots random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() for lora_id, slot_idx in enumerate(random_slot_selections, start=1): slots[slot_idx] = lora_id @@ -134,7 +133,7 @@ def get_random_id_to_index( def populate_loras( - id_to_index: list[Optional[int]], + id_to_index: list[int | None], layer: BaseLayerWithLoRA, layer_weights: torch.Tensor, generate_embeddings_tensor: int = 0, diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 0d9431bd7aae..7bbd1e364d19 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import subprocess import sys -from typing import Union + +import pytest import vllm +import vllm.config from vllm import LLM from vllm.lora.request import LoRARequest from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -27,7 +29,7 @@ def do_sample( llm: vllm.LLM, lora_path: str, lora_id: int, - tensorizer_config_dict: Union[dict, None] = None, + tensorizer_config_dict: dict | None = None, ) -> list[str]: prompts = [ "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 @@ -73,9 +75,7 @@ def do_sample( return generated_texts -def generate_and_test( - llm, sql_lora_files, tensorizer_config_dict: Union[dict, None] = None -): +def generate_and_test(llm, sql_lora_files, tensorizer_config_dict: dict | None = None): print("lora adapter created") print("lora 1") assert ( @@ -103,7 +103,8 @@ def generate_and_test( @create_new_process_for_each_test() -def test_llama_lora(sql_lora_files): +@pytest.mark.parametrize("cudagraph_specialize_lora", [True, False]) +def test_llama_lora(sql_lora_files, cudagraph_specialize_lora: bool): llm = vllm.LLM( MODEL_PATH, tokenizer=sql_lora_files, @@ -111,12 +112,14 @@ def test_llama_lora(sql_lora_files): # also test odd max_num_seqs max_num_seqs=13, max_loras=4, + compilation_config=vllm.config.CompilationConfig( + cudagraph_specialize_lora=cudagraph_specialize_lora, + ), ) generate_and_test(llm, sql_lora_files) @multi_gpu_test(num_gpus=4) -@create_new_process_for_each_test() def test_llama_lora_tp4(sql_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -130,7 +133,6 @@ def test_llama_lora_tp4(sql_lora_files): @multi_gpu_test(num_gpus=4) -@create_new_process_for_each_test() def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): llm = vllm.LLM( MODEL_PATH, @@ 
-145,7 +147,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): @multi_gpu_test(num_gpus=2) -@create_new_process_for_each_test() def test_tp2_serialize_and_deserialize_lora( tmp_path, sql_lora_files, sql_lora_huggingface_id ): diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index ce98fe2f8613..1cf8ed602b6a 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -8,7 +8,7 @@ from vllm.lora.request import LoRARequest from vllm.platforms import current_platform -from ..utils import create_new_process_for_each_test +from ..utils import multi_gpu_test MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" @@ -88,7 +88,7 @@ def test_minicpmv_lora(minicpmv_lora_files): current_platform.is_rocm(), reason="MiniCPM-V dependency xformers incompatible with ROCm", ) -@create_new_process_for_each_test() +@multi_gpu_test(num_gpus=4) def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -112,7 +112,7 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): current_platform.is_rocm(), reason="MiniCPM-V dependency xformers incompatible with ROCm", ) -@create_new_process_for_each_test() +@multi_gpu_test(num_gpus=4) def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, diff --git a/tests/lora/test_moe_lora_align_sum.py b/tests/lora/test_moe_lora_align_sum.py new file mode 100644 index 000000000000..6cd1281c3632 --- /dev/null +++ b/tests/lora/test_moe_lora_align_sum.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import pytest +import torch + +from vllm import _custom_ops as ops + + +def round_up(x, base): + return ((x + base - 1) // base) * base + + +def CEILDIV(x, y): + return (x + y - 1) // y + + +def sample_data(num_experts, max_loras, num_tokens, topk_num): + topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32) + token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32) + + for i in range(num_tokens): + pool = list(range(num_experts)) + random.shuffle(pool) + for j in range(topk_num): + topk_ids[i, j] = pool[j] + token_lora_mapping[i] = random.randint(0, max_loras - 1) + + return topk_ids.to("cuda"), token_lora_mapping.to("cuda") + + +@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920 +@pytest.mark.parametrize("topk_num", [6]) +@pytest.mark.parametrize("num_experts", [64, 128]) +@pytest.mark.parametrize("max_loras", [2, 32]) +@pytest.mark.parametrize("block_size", [16]) +def test_moe_lora_align_block_size( + num_tokens, topk_num, num_experts, max_loras, block_size +): + # sample data + random.seed(1) + topk_ids, token_lora_mapping = sample_data( + num_experts, max_loras, num_tokens, topk_num + ) + + # compute paddings + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) + + # init output tensors + sorted_token_ids = torch.full( + (max_loras * max_num_tokens_padded,), + topk_ids.numel(), + dtype=torch.int32, + device="cuda", + ) + expert_ids = torch.full( + (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda") + + # call kernel + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + 
max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_pad, + ) + + # verify values + expert_ids = expert_ids.view(max_loras, -1) + sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size) + + for lora_idx in range(max_loras): + for token_idx in range(sorted_token_ids.size(1)): + block = sorted_token_ids[lora_idx][token_idx] + indices = block[block != topk_ids.numel()] + if indices.numel() > 0: + expert_id = expert_ids[lora_idx][token_idx] + assert torch.all(topk_ids.view(-1)[indices] == expert_id) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/lora/test_olmoe_tp.py b/tests/lora/test_olmoe_tp.py new file mode 100644 index 000000000000..b954e0776ca4 --- /dev/null +++ b/tests/lora/test_olmoe_tp.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import vllm +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "allenai/OLMoE-1B-7B-0125-Instruct" + +PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request. +" +##Instruction: +candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key. +Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key. +The People_ID of candidate is the foreign key of People_ID of people. + + +###Input: +{context} + +###Response:""" # noqa: E501 + +EXPECTED_LORA_OUTPUT = [ + "SELECT count(*) FROM candidate", + "SELECT count(*) FROM candidate", + "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 + "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 +] + + +def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: + prompts = [ + PROMPT_TEMPLATE.format(context="How many candidates are there?"), + PROMPT_TEMPLATE.format(context="Count the number of candidates."), + PROMPT_TEMPLATE.format( + context="Which poll resource provided the most number of candidate information?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + context="Return the poll resource associated with the most candidates." + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) + # Print the outputs. + generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i]) + + +def test_olmoe_lora(olmoe_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. 
+ llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + + generate_and_test(llm, olmoe_lora_files, lora_id=1) + generate_and_test(llm, olmoe_lora_files, lora_id=2) + + +@multi_gpu_test(num_gpus=2) +def test_olmoe_lora_tp2(olmoe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + tensor_parallel_size=2, + ) + + generate_and_test(llm, olmoe_lora_files, lora_id=1) + generate_and_test(llm, olmoe_lora_files, lora_id=2) + + +@multi_gpu_test(num_gpus=4) +def test_olmoe_lora_tp4(olmoe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + tensor_parallel_size=4, + ) + + generate_and_test(llm, olmoe_lora_files, lora_id=1) + generate_and_test(llm, olmoe_lora_files, lora_id=2) diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py index 2cc8bfe63495..9c55c623d444 100644 --- a/tests/lora/test_peft_helper.py +++ b/tests/lora/test_peft_helper.py @@ -16,11 +16,6 @@ {"r": 1024}, "is greater than max_lora_rank", ), - ( - "test_bias", - {"bias": "all"}, - "Adapter bias cannot be used without bias_enabled", - ), ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index 894263bd0ba3..1800ca107a42 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import pytest @@ -20,7 +19,7 @@ class TestConfig: max_loras: int = 2 max_lora_rank: int = 16 max_model_len: int = 4096 - mm_processor_kwargs: Optional[dict[str, int]] = None + mm_processor_kwargs: dict[str, int] | None = None def __post_init__(self): if self.mm_processor_kwargs is None: @@ -61,7 +60,7 @@ def run_test( self, images: list[ImageAsset], expected_outputs: list[str], - lora_id: Optional[int] = None, + lora_id: int | None = None, temperature: float = 0, max_tokens: int = 5, ): @@ -92,7 +91,7 @@ def run_beam_search_test( self, images: list[ImageAsset], expected_outputs: list[list[str]], - lora_id: Optional[int] = None, + lora_id: int | None = None, temperature: float = 0, beam_width: int = 2, max_tokens: int = 5, diff --git a/tests/lora/test_qwen3moe_tp.py b/tests/lora/test_qwen3moe_tp.py new file mode 100644 index 000000000000..de2b040907f9 --- /dev/null +++ b/tests/lora/test_qwen3moe_tp.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import vllm +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "Qwen/Qwen3-30B-A3B" + +PROMPT_TEMPLATE = """<|im_start|>user +I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request. +" +##Instruction: +candidate_poll contains tables such as candidate, people. Table candidate has columns such as Candidate_ID, People_ID, Poll_Source, Date, Support_rate, Consider_rate, Oppose_rate, Unsure_rate. Candidate_ID is the primary key. 
+Table people has columns such as People_ID, Sex, Name, Date_of_Birth, Height, Weight. People_ID is the primary key. +The People_ID of candidate is the foreign key of People_ID of people. + + +###Input: +{context} + +###Response:<|im_end|> +<|im_start|>assistant""" # noqa: E501 + +EXPECTED_LORA_OUTPUT = [ + "<think>\n\n</think>\n\nSELECT count(*) FROM candidate", + "<think>\n\n</think>\n\nSELECT count(*) FROM candidate", + "<think>\n\n</think>\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 + "<think>\n\n</think>\n\nSELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501 +] + + +def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: + prompts = [ + PROMPT_TEMPLATE.format(context="How many candidates are there?"), + PROMPT_TEMPLATE.format(context="Count the number of candidates."), + PROMPT_TEMPLATE.format( + context="Which poll resource provided the most number of candidate information?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + context="Return the poll resource associated with the most candidates." + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) + # Print the outputs. + generated_texts: list[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i]) + + +def test_qwen3moe_lora(qwen3moe_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. 
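+    # Qwen3-30B-A3B is a mixture-of-experts model, so this covers the MoE LoRA
+    # path at TP=1; TP=2 and TP=4 variants of the same check follow below.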
+ llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + + generate_and_test(llm, qwen3moe_lora_files, lora_id=1) + generate_and_test(llm, qwen3moe_lora_files, lora_id=2) + + +@multi_gpu_test(num_gpus=2) +def test_qwen3moe_lora_tp2(qwen3moe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + tensor_parallel_size=2, + ) + + generate_and_test(llm, qwen3moe_lora_files, lora_id=1) + generate_and_test(llm, qwen3moe_lora_files, lora_id=2) + + +@multi_gpu_test(num_gpus=4) +def test_qwen3moe_lora_tp4(qwen3moe_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + tensor_parallel_size=4, + ) + + generate_and_test(llm, qwen3moe_lora_files, lora_id=1) + generate_and_test(llm, qwen3moe_lora_files, lora_id=2) diff --git a/tests/lora/test_resolver.py b/tests/lora/test_resolver.py index c70e58a375c7..9b5dedc4327f 100644 --- a/tests/lora/test_resolver.py +++ b/tests/lora/test_resolver.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest @@ -14,7 +13,7 @@ class DummyLoRAResolver(LoRAResolver): async def resolve_lora( self, base_model_name: str, lora_name: str - ) -> Optional[LoRARequest]: + ) -> LoRARequest | None: if lora_name == "test_lora": return LoRARequest( lora_name=lora_name, diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index aed91d98ddbd..eb026c2ec020 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import OrderedDict -from typing import NamedTuple, Optional +from typing import NamedTuple from unittest.mock import patch import pytest @@ -21,8 +21,7 @@ class LoRANameParserTestConfig(NamedTuple): name: str module_name: str is_lora_a: bool - is_bias: bool - weights_mapper: Optional[WeightsMapper] = None + weights_mapper: WeightsMapper | None = None def test_parse_fine_tuned_lora_name_valid(): @@ -37,44 +36,37 @@ def test_parse_fine_tuned_lora_name_valid(): "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", True, - False, ), LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_B", "model.embed_tokens", False, - False, ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "model.layers.9.mlp.down_proj", True, - False, ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "model.layers.9.mlp.down_proj", False, - False, ), LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_A.weight", "language_model.layers.9.mlp.down_proj", True, - False, ), LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_B.weight", "language_model.layers.9.mlp.down_proj", False, - False, ), # Test with WeightsMapper LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "language_model.model.layers.9.mlp.down_proj", True, - False, weights_mapper=WeightsMapper( orig_to_new_prefix={"model.": "language_model.model."} ), @@ -83,7 +75,6 @@ def test_parse_fine_tuned_lora_name_valid(): 
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "language_model.model.layers.9.mlp.down_proj", False, - False, weights_mapper=WeightsMapper( orig_to_new_prefix={"model.": "language_model.model."} ), @@ -92,7 +83,6 @@ def test_parse_fine_tuned_lora_name_valid(): "model.layers.9.mlp.down_proj.lora_A.weight", "language_model.model.layers.9.mlp.down_proj", True, - False, weights_mapper=WeightsMapper( orig_to_new_prefix={"model.": "language_model.model."} ), @@ -101,14 +91,13 @@ def test_parse_fine_tuned_lora_name_valid(): "model.layers.9.mlp.down_proj.lora_B.weight", "language_model.model.layers.9.mlp.down_proj", False, - False, weights_mapper=WeightsMapper( orig_to_new_prefix={"model.": "language_model.model."} ), ), ] - for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: - assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name( + for name, module_name, is_lora_a, weights_mapper in fixture: + assert (module_name, is_lora_a) == parse_fine_tuned_lora_name( name, weights_mapper ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b522aa6b0874..d30b77f09466 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -4,7 +4,6 @@ import json import os from dataclasses import dataclass -from typing import Optional, Union import torch from safetensors.torch import save_file @@ -81,7 +80,7 @@ def init_packed_lora( module_name: str, input_dim: int, output_dims: list[int], - noop_lora_index: Optional[list[int]] = None, + noop_lora_index: list[int] | None = None, rank: int = 8, ): base_loras: list[LoRALayerWeights] = [] @@ -113,7 +112,7 @@ def assert_close(a, b): @dataclass class PunicaTensors: inputs_tensor: torch.Tensor - lora_weights: Union[torch.Tensor, list[torch.Tensor]] + lora_weights: torch.Tensor | list[torch.Tensor] our_out_tensor: torch.Tensor ref_out_tensor: torch.Tensor b_seq_start_loc: torch.Tensor diff --git a/tests/model_executor/model_loader/tensorizer_loader/conftest.py b/tests/model_executor/model_loader/tensorizer_loader/conftest.py index add6d3742ff5..826ecec71e6c 100644 --- a/tests/model_executor/model_loader/tensorizer_loader/conftest.py +++ b/tests/model_executor/model_loader/tensorizer_loader/conftest.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable +from collections.abc import Callable import pytest @@ -8,8 +8,8 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.model_loader import tensorizer as tensorizer_mod from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.utils import get_distributed_init_method, get_ip, get_open_port -from vllm.v1.executor.abstract import UniProcExecutor +from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.executor import UniProcExecutor from vllm.v1.worker.worker_base import WorkerWrapperBase MODEL_REF = "facebook/opt-125m" diff --git a/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py index 57db1f98baed..ed5129e1c820 100644 --- a/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py +++ b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py @@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.tensorizer_loader import ( BLACKLISTED_TENSORIZER_ARGS, ) -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule 
from .conftest import DummyExecutor, assert_from_collective_rpc diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index ab3a3a8268a3..41419553aa83 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -37,7 +36,7 @@ class Relu3(ReLUSquaredActivation): @pytest.mark.parametrize( - "env, torch_level, backend, ops_enabled, default_on", + "env, compilation_mode, backend, ops_enabled, default_on", [ # Default values based on compile level # - All by default (no Inductor compilation) @@ -77,8 +76,8 @@ class Relu3(ReLUSquaredActivation): ], ) def test_enabled_ops( - env: Optional[str], - torch_level: int, + env: str | None, + compilation_mode: int, backend: str, ops_enabled: list[int], default_on: bool, @@ -86,10 +85,9 @@ def test_enabled_ops( custom_ops = env.split(",") if env else [] vllm_config = VllmConfig( compilation_config=CompilationConfig( - backend=backend, level=torch_level, custom_ops=custom_ops + backend=backend, mode=compilation_mode, custom_ops=custom_ops ) ) - # breakpoint() with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 3fc265194e2a..ad37d1ad82c0 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -100,7 +99,7 @@ "allenai/OLMoE-1B-7B-0924-Instruct", marks=[pytest.mark.cpu_model], ), - pytest.param("swiss-ai/Apertus-8B-2509"), # apertus + pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus ], ) @pytest.mark.parametrize("max_tokens", [32]) @@ -138,7 +137,7 @@ def test_models( example_prompts, max_tokens, num_logprobs ) - prompt_embeds: Optional[list[torch.Tensor]] = [] if use_prompt_embeds else None + prompt_embeds: list[torch.Tensor] | None = [] if use_prompt_embeds else None prompt_token_ids = [] for prompt in example_prompts: diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index abedd15b0d7e..fd2df329f17f 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable +from collections.abc import Callable import pytest diff --git a/tests/models/language/generation_ppl_test/ppl_utils.py b/tests/models/language/generation_ppl_test/ppl_utils.py index dcef365e99e7..59740505e827 100644 --- a/tests/models/language/generation_ppl_test/ppl_utils.py +++ b/tests/models/language/generation_ppl_test/ppl_utils.py @@ -1,14 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from https://huggingface.co/docs/transformers/perplexity -from typing import Optional, cast +from typing import cast -import pytest import torch from datasets import load_dataset import tests.ci_envs as ci_envs -from tests.models.utils import GenerateModelInfo, 
TokensTextLogprobsPromptLogprobs +from tests.models.utils import ( + GenerateModelInfo, + TokensTextLogprobsPromptLogprobs, + get_vllm_extra_kwargs, +) from vllm.logprobs import Logprob # See #24485 @@ -25,33 +28,15 @@ def wikitext_ppl_test( vllm_extra_kwargs=None, atol=PPL_TOL, ): - # A model family has many models with the same architecture, - # and we don't need to test each one. - if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: - pytest.skip("Skipping test.") + vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs) dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") - # Allow vllm to test using the given dtype, such as float32 - vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype - - # Allow vllm to test using hf_overrides - if model_info.hf_overrides is not None: - vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides - - # Allow changing the head dtype used by vllm in tests - if ci_envs.VLLM_CI_HEAD_DTYPE is not None: - if "hf_overrides" not in vllm_extra_kwargs: - vllm_extra_kwargs["hf_overrides"] = {} - vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE - with vllm_runner( model_info.name, gpu_memory_utilization=0.7, max_model_len=max_length, max_num_seqs=1, - enforce_eager=True, **vllm_extra_kwargs, ) as vllm_model: # Use max_num_seqs=1 to avoid OOM, @@ -86,7 +71,7 @@ def wikitext_ppl_test( n_tokens = 0 for output in outputs: output = cast(TokensTextLogprobsPromptLogprobs, output) - token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + token_datas = cast(list[dict[int, Logprob] | None], output[3]) assert token_datas[0] is None token_log_probs = [] diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 261ab80ae86b..4ac40656bc62 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional import pytest @@ -13,7 +12,7 @@ def run_embedding_correctness_test( hf_model: "HfRunner", inputs: list[str], vllm_outputs: Sequence[list[float]], - dimensions: Optional[int] = None, + dimensions: int | None = None, ): hf_outputs = hf_model.encode(inputs) if dimensions: diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py index e95119df95c7..0904c7e877ef 100644 --- a/tests/models/language/pooling/test_auto_prefix_cache_support.py +++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py @@ -19,14 +19,25 @@ def test_classify_models( model: str, dtype: str, ) -> None: - example_prompts = example_prompts * 2 + # example_prompts is too short for testing prefix_caching + example_prompts = [s * 10 for s in example_prompts] with vllm_runner( model, max_model_len=512, dtype=dtype, enable_prefix_caching=True ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching - vllm_outputs = vllm_model.classify(example_prompts) + + # First Run + vllm_model.classify(example_prompts) + + # assert prefix_caching works + pooling_outputs = vllm_model.llm.encode( + example_prompts, pooling_task="classify" + ) + for output in pooling_outputs: + assert output.num_cached_tokens > 0 + vllm_outputs = [req_output.outputs.data 
for req_output in pooling_outputs] with hf_runner( model, dtype=dtype, auto_cls=AutoModelForSequenceClassification @@ -54,7 +65,8 @@ def test_embed_models( model: str, dtype: str, ): - example_prompts = [str(s).strip() for s in example_prompts] * 2 + # example_prompts is too short for testing prefix_caching + example_prompts = [str(s).strip() * 10 for s in example_prompts] with vllm_runner( model, @@ -64,7 +76,15 @@ def test_embed_models( ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching - vllm_outputs = vllm_model.embed(example_prompts) + + # First Run + vllm_model.embed(example_prompts) + + # assert prefix_caching works + pooling_outputs = vllm_model.llm.encode(example_prompts, pooling_task="embed") + for output in pooling_outputs: + assert output.num_cached_tokens > 0 + vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs] with hf_runner( model, diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index c9574dca498e..c8deffbf66db 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest @@ -66,7 +65,7 @@ def test_models( pooling_type="MEAN", normalize=False ) - max_model_len: Optional[int] = 512 + max_model_len: int | None = 512 if model in [ "sentence-transformers/all-MiniLM-L12-v2", "sentence-transformers/stsb-roberta-base-v2", diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py new file mode 100644 index 000000000000..f8e3fa7d1560 --- /dev/null +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm import TokensPrompt + + +@pytest.mark.parametrize( + "model", + ["Qwen/Qwen3-0.6B"], +) +@torch.inference_mode +def test_embed_models(hf_runner, vllm_runner, model: str): + n_prompt_tokens = [55, 56, 57] + token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] + + with vllm_runner( + model, + max_model_len=128, + enforce_eager=True, + runner="pooling", + enable_chunked_prefill=False, + enable_prefix_caching=False, + ) as vllm_model: + pooling_outputs = vllm_model.llm.encode( + [TokensPrompt(prompt_token_ids=t) for t in token_prompts], + pooling_task="token_embed", + ) + + for n, output in zip(n_prompt_tokens, pooling_outputs): + assert len(output.prompt_token_ids) == n + assert output.num_cached_tokens == 0 diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index 14308ac06c03..0adc9b5cf25f 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import numpy as np import openai import pytest diff --git a/tests/models/language/pooling/test_head_dtype.py b/tests/models/language/pooling/test_head_dtype.py new file mode 100644 index 000000000000..b60d4dade49a --- /dev/null +++ b/tests/models/language/pooling/test_head_dtype.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModelForSequenceClassification + + +@pytest.mark.parametrize( + "model", + ["nie3e/sentiment-polish-gpt2-small"], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_classify_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForSequenceClassification + ) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for head_dtype_str in ["float32", "model"]: + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + hf_overrides={"head_dtype": head_dtype_str}, + ) as vllm_model: + model_config = vllm_model.llm.llm_engine.model_config + model_dtype = model_config.dtype + head_dtype = model_config.head_dtype + + if head_dtype_str == "float32": + assert head_dtype == torch.float32 + elif head_dtype_str == "model": + assert head_dtype == model_dtype + + vllm_outputs = vllm_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output).float() + vllm_output = torch.tensor(vllm_output).float() + + assert torch.allclose(hf_output, vllm_output, atol=1e-2) diff --git a/tests/models/language/pooling/test_multi_vector_retrieval.py b/tests/models/language/pooling/test_multi_vector_retrieval.py new file mode 100644 index 000000000000..302f2df13557 --- /dev/null +++ b/tests/models/language/pooling/test_multi_vector_retrieval.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModel + +from tests.models.utils import check_embeddings_close + + +@pytest.mark.parametrize( + "model", + ["BAAI/bge-m3"], +) +@pytest.mark.parametrize("dtype", ["half"]) +@torch.inference_mode +def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str): + with vllm_runner( + model, + runner="pooling", + max_model_len=None, + ) as vllm_model: + vllm_outputs = vllm_model.token_embed(example_prompts) + + with hf_runner( + model, + auto_cls=AutoModel, + ) as hf_model: + tokenizer = hf_model.tokenizer + hf_outputs = [] + for prompt in example_prompts: + inputs = tokenizer([prompt], return_tensors="pt") + inputs = hf_model.wrap_device(inputs) + output = hf_model.model(**inputs) + embedding = output.last_hidden_state[0].float() + # normal + hf_outputs.append(embedding.cpu()) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + check_embeddings_close( + embeddings_0_lst=hf_output, + embeddings_1_lst=vllm_output, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/language/pooling/test_pooler_config_init_behaviour.py b/tests/models/language/pooling/test_pooler_config_init_behaviour.py index 674bf02b7b98..55663ee3f1b4 100644 --- a/tests/models/language/pooling/test_pooler_config_init_behaviour.py +++ b/tests/models/language/pooling/test_pooler_config_init_behaviour.py @@ -93,7 +93,7 @@ def test_embed_models_using_normalize( ], ) @pytest.mark.parametrize("dtype", ["half"]) -def test_reward_models_using_softmax( +def test_reward_models_using_activation( hf_runner, vllm_runner, example_prompts, @@ -104,22 +104,64 @@ def test_reward_models_using_softmax( model, max_model_len=1024, dtype=dtype, - pooler_config=PoolerConfig(softmax=False), + pooler_config=PoolerConfig(activation=False), ) as vllm_model: - wo_softmax = 
vllm_model.encode(example_prompts) + wo_activation = vllm_model.reward(example_prompts) with vllm_runner( - model, max_model_len=1024, dtype=dtype, pooler_config=PoolerConfig(softmax=True) + model, + max_model_len=1024, + dtype=dtype, + pooler_config=PoolerConfig(activation=True), ) as vllm_model: - w_softmax = vllm_model.encode(example_prompts) + w_activation = vllm_model.reward(example_prompts) - for wo, w in zip(wo_softmax, w_softmax): + for wo, w in zip(wo_activation, w_activation): wo = torch.tensor(wo) w = torch.tensor(w) assert not torch.allclose(wo, w, atol=1e-2), ( - "pooler_config softmax is not working" + "pooler_config activation is not working" ) assert torch.allclose(softmax(wo), w, atol=1e-2), ( - "w_softmax should be close to softmax(wo_softmax)." + "w_activation should be close to activation(wo_activation)." + ) + + +@pytest.mark.parametrize( + "model", + [ + "intfloat/multilingual-e5-small", + ], +) +@pytest.mark.parametrize("dtype", ["half"]) +def test_multi_vector_retrieval_models_using_normalize( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=False), + ) as vllm_model: + wo_normalize = vllm_model.token_embed(example_prompts) + + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=True), + ) as vllm_model: + w_normalize = vllm_model.token_embed(example_prompts) + + for wo, w in zip(wo_normalize, w_normalize): + assert not torch.allclose(wo, w, atol=1e-2), ( + "pooler_config normalize is not working" + ) + assert torch.allclose(F.normalize(wo, p=2, dim=-1), w, atol=1e-2), ( + "w_normal should be close to normal(wo_normal)." ) diff --git a/tests/models/language/pooling/test_splade_sparse_pooler.py b/tests/models/language/pooling/test_splade_sparse_pooler.py new file mode 100644 index 000000000000..af4fd764ef53 --- /dev/null +++ b/tests/models/language/pooling/test_splade_sparse_pooler.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import types + +import pytest +import torch +import torch.nn as nn + +from vllm.model_executor.models.bert import ( + BertMLMHead, + SPLADESparsePooler, +) + +# --------------------------------------------------------------------- +# Functional test: SPLADE formula correctness (no HF download needed) +# --------------------------------------------------------------------- + + +@pytest.mark.parametrize("B,T,H,V", [(2, 3, 5, 7)]) +@torch.inference_mode +def test_splade_pooler_matches_reference_formula(B, T, H, V): + """Ensure SPLADESparsePooler forward() matches the mathematical formula: + log1p(relu(logits)) -> max over sequence length (after masking).""" + torch.manual_seed(0) + + # Prepare [B] sequences of shape [T, H] + hs_list = [torch.randn(T, H) for _ in range(B)] + hs_tenser = torch.cat(hs_list) + + # Simulate PoolingMetadata (only required fields) + prompt_lens = [T, T - 1] + prompt_lens_tenser = torch.tensor(prompt_lens, dtype=torch.int32) + token_ids = torch.tensor( + [ + [101, 5, 102], # Batch 0: [CLS], token, [SEP] + [101, 6, 6], # Batch 1: [CLS], token, token (last token ignored) + ], + dtype=torch.long, + ) + meta = types.SimpleNamespace( + prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids + ) + + # MLM head (prefer BertMLMHead, fallback to Linear if unavailable) + try: + mlm_head = BertMLMHead(hidden_size=H, vocab_size=V, layer_norm_eps=1e-12) + 
except Exception: + mlm_head = nn.Linear(H, V, bias=True) + + # Forward pass through SPLADE pooler + pooler = SPLADESparsePooler(mlm_head=mlm_head, pooling="max", remove_cls_sep=True) + pooled = pooler(hidden_states=hs_tenser, pooling_metadata=meta) # list of [V] + + # Basic output checks + assert isinstance(pooled, torch.Tensor) and len(pooled) == B + for vec in pooled: + assert vec.shape == (V,) + assert torch.isfinite(vec).all() + assert (vec >= 0).all(), "SPLADE outputs must be non-negative." + + # Reference implementation for comparison + def ref_one(hs: torch.Tensor, L: int, tid_row: torch.Tensor) -> torch.Tensor: + keep = torch.ones(L, dtype=torch.bool) + if L > 0 and tid_row[0].item() == 101: # remove CLS + keep[0] = False + if L > 0 and tid_row[L - 1].item() == 102: # remove SEP + keep[L - 1] = False + + valid = hs[:L][keep[:L]] + if valid.numel() == 0: + return torch.zeros(V, dtype=torch.float32) + + logits = mlm_head(valid) # [L', V] + scores = torch.log1p(torch.relu(logits)) # [L', V] + return scores.max(dim=0).values.to(torch.float32) + + torch.testing.assert_close( + pooled[0], + ref_one(hs_list[0], prompt_lens[0], token_ids[0]), + rtol=1e-4, + atol=1e-4, + ) + torch.testing.assert_close( + pooled[1], + ref_one(hs_list[1], prompt_lens[1], token_ids[1]), + rtol=1e-4, + atol=1e-4, + ) diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py index f72dfb46d9fd..2dfc0072126b 100644 --- a/tests/models/language/pooling/test_token_classification.py +++ b/tests/models/language/pooling/test_token_classification.py @@ -19,7 +19,7 @@ def test_bert_models( dtype: str, ) -> None: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = vllm_model.token_classify(example_prompts) with hf_runner( model, dtype=dtype, auto_cls=AutoModelForTokenClassification @@ -50,7 +50,7 @@ def test_modernbert_models( dtype: str, ) -> None: with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = vllm_model.token_classify(example_prompts) with hf_runner( model, dtype=dtype, auto_cls=AutoModelForTokenClassification @@ -67,4 +67,4 @@ def test_modernbert_models( for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): hf_output = torch.tensor(hf_output).cpu().float() vllm_output = torch.tensor(vllm_output).cpu().float() - assert torch.allclose(hf_output, vllm_output, 1e-2) + assert torch.allclose(hf_output, vllm_output, atol=1e-2) diff --git a/tests/models/language/pooling_mteb_test/mteb_utils.py b/tests/models/language/pooling_mteb_test/mteb_utils.py index a4a7f1b48d3d..ee1a1ca8d653 100644 --- a/tests/models/language/pooling_mteb_test/mteb_utils.py +++ b/tests/models/language/pooling_mteb_test/mteb_utils.py @@ -3,16 +3,19 @@ import tempfile from collections.abc import Sequence -from typing import Optional import mteb import numpy as np -import pytest import requests import torch import tests.ci_envs as ci_envs -from tests.models.utils import EmbedModelInfo, RerankModelInfo, check_embeddings_close +from tests.models.utils import ( + EmbedModelInfo, + RerankModelInfo, + check_embeddings_close, + get_vllm_extra_kwargs, +) # Most embedding models on the STS12 task (See #17175): # - Model implementation and minor changes in tensor dtype @@ -51,7 +54,7 @@ def encode( def predict( self, - sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt + sentences: 
list[tuple[str, str, str | None]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: @@ -100,7 +103,7 @@ def __init__(self, model_name: str, url): def predict( self, - sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, str | None]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: @@ -166,33 +169,15 @@ def mteb_test_embed_models( hf_model_callback=None, atol=MTEB_EMBED_TOL, ): - # A model family has many models with the same architecture, - # and we don't need to test each one. - if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: - pytest.skip("Skipping test.") + vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs) # Test embed_dims, isnan and whether to use normalize example_prompts = ["The chef prepared a delicious meal." * 1000] - # Allow vllm to test using the given dtype, such as float32 - vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype - - # Allow vllm to test using hf_overrides - if model_info.hf_overrides is not None: - vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides - - # Allow changing the head dtype used by vllm in tests - if ci_envs.VLLM_CI_HEAD_DTYPE is not None: - if "hf_overrides" not in vllm_extra_kwargs: - vllm_extra_kwargs["hf_overrides"] = {} - vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE - with vllm_runner( model_info.name, runner="pooling", max_model_len=None, - enforce_eager=True, **vllm_extra_kwargs, ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config @@ -214,9 +199,12 @@ def mteb_test_embed_models( vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype head_dtype = model_config.head_dtype - # Test embed_dims, isnan and whether to use normalize + # Test embedding_size, isnan and whether to use normalize vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1) - assert not torch.any(torch.isnan(torch.tensor(vllm_outputs))) + outputs_tensor = torch.tensor(vllm_outputs) + assert not torch.any(torch.isnan(outputs_tensor)) + embedding_size = model_config.embedding_size + assert torch.tensor(vllm_outputs).shape[-1] == embedding_size # Accelerate mteb test by setting # SentenceTransformers mteb score to a constant @@ -233,7 +221,7 @@ def mteb_test_embed_models( st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) st_dtype = next(hf_model.model.parameters()).dtype - # Test embed_dims and whether to use normalize + # Check embeddings close to hf outputs hf_outputs = hf_model.encode(example_prompts) check_embeddings_close( embeddings_0_lst=hf_outputs, @@ -295,7 +283,7 @@ def mteb_test_rerank_models_hf( original_predict = hf_model.predict def _predict( - sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, str | None]], # query, corpus, prompt *args, **kwargs, ): @@ -325,31 +313,13 @@ def mteb_test_rerank_models( vllm_mteb_encoder=VllmMtebEncoder, atol=MTEB_RERANK_TOL, ): - # A model family has many models with the same architecture, - # and we don't need to test each one. 
- if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: - pytest.skip("Skipping test.") - - # Allow vllm to test using the given dtype, such as float32 - vllm_extra_kwargs = vllm_extra_kwargs or {} - vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype - - # Allow vllm to test using hf_overrides - if model_info.hf_overrides is not None: - vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides - - # Allow changing the head dtype used by vllm in tests - if ci_envs.VLLM_CI_HEAD_DTYPE is not None: - if "hf_overrides" not in vllm_extra_kwargs: - vllm_extra_kwargs["hf_overrides"] = {} - vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs) with vllm_runner( model_info.name, runner="pooling", max_model_len=None, max_num_seqs=8, - enforce_eager=True, **vllm_extra_kwargs, ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config diff --git a/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py index 9e95dd74c397..2927a3711136 100644 --- a/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any import numpy as np import pytest @@ -111,7 +111,7 @@ def __init__(self, *args, **kwargs): def predict( self, - sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, str | None]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: diff --git a/tests/models/language/pooling_mteb_test/test_jina.py b/tests/models/language/pooling_mteb_test/test_jina.py index 0a712b2542f3..c2065bcd6eb4 100644 --- a/tests/models/language/pooling_mteb_test/test_jina.py +++ b/tests/models/language/pooling_mteb_test/test_jina.py @@ -25,6 +25,7 @@ mteb_score=0.824413164, architecture="XLMRobertaModel", is_matryoshka=True, + dtype="float32", ) ] diff --git a/tests/models/language/pooling_mteb_test/test_st_projector.py b/tests/models/language/pooling_mteb_test/test_st_projector.py index 91b1ef828d0d..74fe4b9bcc03 100644 --- a/tests/models/language/pooling_mteb_test/test_st_projector.py +++ b/tests/models/language/pooling_mteb_test/test_st_projector.py @@ -23,6 +23,7 @@ architecture="Gemma3TextModel", mteb_score=0.7473819294684156, enable_test=True, + dtype="float32", ), ] diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 475c2ad55f73..f11f75418e7d 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -17,7 +17,7 @@ ) from vllm.platforms import current_platform -from vllm.utils import identity +from vllm.utils.func_utils import identity from ....conftest import ( IMAGE_ASSETS, @@ -109,8 +109,7 @@ limit_mm_per_prompt={"image": 4}, ) ], - # TODO: Revert to "auto" when CPU backend can use torch > 2.6 - dtype="bfloat16" if current_platform.is_cpu() else "auto", + vllm_runner_kwargs={"enable_mm_embeds": True}, marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "paligemma": VLMTestInfo( @@ -707,8 +706,6 @@ max_num_seqs=2, vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output, prompt_path_encoder=model_utils.qwen_prompt_path_encoder, - # FIXME: 
https://github.com/huggingface/transformers/issues/38358 - marks=[pytest.mark.skip("Model initialization fails")], ), "qwen2_vl": VLMTestInfo( models=["Qwen/Qwen2-VL-2B-Instruct"], @@ -749,6 +746,7 @@ max_num_seqs=2, auto_cls=AutoModelForImageTextToText, hf_output_post_proc=model_utils.smolvlm_trunc_hf_output, + num_logprobs=10, ), "tarsier": VLMTestInfo( models=["omni-research/Tarsier-7b"], diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index ef08b1916aa5..e39dfc888779 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional import pytest from transformers import AutoModelForSpeechSeq2Seq @@ -18,8 +17,8 @@ def vllm_to_hf_output( - vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], -) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: tuple[list[int], str, SampleLogprobs | None], +) -> tuple[list[int], str, SampleLogprobs | None]: """Sanitize hf output to be comparable with vllm output.""" output_ids, output_str, out_logprobs = vllm_output @@ -46,7 +45,7 @@ def run_test( max_tokens: int, num_logprobs: int, tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + distributed_executor_backend: str | None = None, ): """Inference result should be the same between hf and vllm. diff --git a/tests/models/multimodal/generation/test_phi4_multimodal.py b/tests/models/multimodal/generation/test_phi4_multimodal.py index 132c69285c5c..cbc7dfca0234 100644 --- a/tests/models/multimodal/generation/test_phi4_multimodal.py +++ b/tests/models/multimodal/generation/test_phi4_multimodal.py @@ -3,7 +3,6 @@ import os from collections.abc import Sequence -from typing import Optional import librosa import pytest @@ -57,7 +56,7 @@ def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, Optional[PromptAudioInput]]], + inputs: Sequence[tuple[list[str], PromptImageInput, PromptAudioInput | None]], model: str, *, max_model_len: int, @@ -66,7 +65,7 @@ def run_test( num_logprobs: int, mm_limit: int, tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + distributed_executor_backend: str | None = None, ): """Inference result should be the same between hf and vllm. 
diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index e69d44c6a131..5619cecc081d 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -3,7 +3,6 @@ import os from collections.abc import Sequence -from typing import Optional import librosa import pytest @@ -48,7 +47,7 @@ def vllm_to_hf_output( - vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], model: str + vllm_output: tuple[list[int], str, SampleLogprobs | None], model: str ): """Sanitize vllm output to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -79,7 +78,7 @@ def vllm_to_hf_output( def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, Optional[PromptAudioInput]]], + inputs: Sequence[tuple[list[str], PromptImageInput, PromptAudioInput | None]], model: str, *, max_model_len: int, @@ -88,7 +87,7 @@ def run_test( num_logprobs: int, mm_limit: int, tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + distributed_executor_backend: str | None = None, ): """Inference result should be the same between hf and vllm. diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index db0effdaf666..3cad2c43d562 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -2,11 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import pytest from mistral_common.multimodal import download_image -from mistral_common.protocol.instruct.messages import ImageURLChunk +from mistral_common.protocol.instruct.chunk import ImageURLChunk from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.multimodal import image_from_chunk @@ -117,7 +117,7 @@ def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt: MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json", } -OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]] +OutputsLogprobs = list[tuple[list[int], str, SampleLogprobs | None]] # For the test author to store golden output in JSON diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index a8f0ba870185..e10b8e1e77af 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, TypedDict, Union +from typing import Any, TypedDict import numpy.typing as npt import pytest @@ -83,7 +83,7 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict): def batch_make_image_embeddings( - image_batches: list[Union[Image.Image, list[Image.Image]]], + image_batches: list[Image.Image | list[Image.Image]], processor, llm: VllmRunner, ) -> list[Qwen2VLPromptImageEmbeddingInput]: @@ -272,7 +272,7 @@ def run_embedding_input_test( num_logprobs: int, mm_limit: int, tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + distributed_executor_backend: str | None = None, ): 
"""Inference result should be the same between original image/video input and image/video embeddings input. @@ -292,6 +292,7 @@ def run_embedding_input_test( tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, default_torch_num_threads=1, + enable_mm_embeds=True, ) as vllm_model: outputs_per_case_for_original_input = [ vllm_model.generate_greedy_logprobs( diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index d27b3ab5ff47..18a50c3a555d 100644 --- a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -6,12 +6,8 @@ import pytest import pytest_asyncio from mistral_common.audio import Audio -from mistral_common.protocol.instruct.messages import ( - AudioChunk, - RawAudio, - TextChunk, - UserMessage, -) +from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from vllm.transformers_utils.tokenizer import MistralTokenizer diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 766f09b0d320..eca2b61e37d5 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest @@ -92,7 +91,7 @@ def run_test( model: str, *, tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + distributed_executor_backend: str | None = None, ) -> None: prompt_list = PROMPTS * 10 expected_list = EXPECTED[model] * 10 diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index 096931cca09f..6252f33bdfad 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -2,9 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Helpers for building inputs that can be leveraged for different test types.""" -from collections.abc import Iterable +from collections.abc import Callable, Iterable from pathlib import PosixPath -from typing import Callable, Optional, Union import torch @@ -47,9 +46,9 @@ def replace_test_placeholder( def get_model_prompts( base_prompts: Iterable[str], - img_idx_to_prompt: Optional[Callable[[int], str]], - video_idx_to_prompt: Optional[Callable[[int], str]], - audio_idx_to_prompt: Optional[Callable[[int], str]], + img_idx_to_prompt: Callable[[int], str] | None, + video_idx_to_prompt: Callable[[int], str] | None, + audio_idx_to_prompt: Callable[[int], str] | None, prompt_formatter: Callable[[str], str], ) -> list[str]: """Given a model-agnostic base prompt and test configuration for a model(s) @@ -93,7 +92,7 @@ def build_single_image_inputs_from_test_info( test_info: VLMTestInfo, image_assets: ImageTestAssets, size_wrapper: ImageSizeWrapper, - tmp_path: Optional[PosixPath] = None, + tmp_path: PosixPath | None = None, ) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: raise ValueError("Prompt formatter must be set to build single image inputs") @@ -147,7 +146,7 @@ def build_multi_image_inputs_from_test_info( test_info: VLMTestInfo, image_assets: ImageTestAssets, size_wrapper: ImageSizeWrapper, - tmp_path: Optional[PosixPath] = None, + 
tmp_path: PosixPath | None = None, ) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: raise ValueError("Prompt formatter must be set to build multi image inputs") @@ -266,9 +265,7 @@ def build_video_inputs_from_test_info( ] -def apply_image_size_scaling( - image, size: Union[float, tuple[int, int]], size_type: SizeType -): +def apply_image_size_scaling(image, size: float | tuple[int, int], size_type: SizeType): """Applies a size scaler to one image; this can be an image size factor, which scales the image while maintaining the aspect ratio""" # Special case for embeddings; if it's a tensor, it's only valid if we diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index 0c11f5f9b082..03ff3bcf6307 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -2,12 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Core test implementation to be shared across modalities.""" -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import torch from transformers.models.auto.auto_factory import _BaseAutoModelClass -from vllm.config import RunnerOption +from vllm.config.model import RunnerOption from vllm.transformers_utils.tokenizer import AnyTokenizer from .....conftest import HfRunner, VllmRunner @@ -27,21 +28,21 @@ def run_test( enforce_eager: bool, max_model_len: int, max_num_seqs: int, - hf_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]], - vllm_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]], + hf_output_post_proc: Callable[[RunnerOutput, str], Any] | None, + vllm_output_post_proc: Callable[[RunnerOutput, str], Any] | None, auto_cls: type[_BaseAutoModelClass], use_tokenizer_eos: bool, comparator: Callable[..., None], - get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]], - stop_str: Optional[list[str]], + get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None, + stop_str: list[str] | None, limit_mm_per_prompt: dict[str, int], - vllm_runner_kwargs: Optional[dict[str, Any]], - hf_model_kwargs: Optional[dict[str, Any]], - patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], + vllm_runner_kwargs: dict[str, Any] | None, + hf_model_kwargs: dict[str, Any] | None, + patch_hf_runner: Callable[[HfRunner], HfRunner] | None, runner: RunnerOption = "auto", - distributed_executor_backend: Optional[str] = None, + distributed_executor_backend: str | None = None, tensor_parallel_size: int = 1, - vllm_embeddings: Optional[torch.Tensor] = None, + vllm_embeddings: torch.Tensor | None = None, ): """Modality agnostic test executor for comparing HF/vLLM outputs.""" # In the case of embeddings, vLLM takes separate input tensors @@ -70,8 +71,9 @@ def run_test( vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode if model_info.hf_overrides: vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides - if model_info.skip_tokenizer_init: - vllm_runner_kwargs_["skip_tokenizer_init"] = model_info.skip_tokenizer_init + if model_info.require_embed_inputs: + for k in ("skip_tokenizer_init", "enable_prompt_embeds", "enable_mm_embeds"): + vllm_runner_kwargs_[k] = model_info.require_embed_inputs if vllm_runner_kwargs: vllm_runner_kwargs_.update(vllm_runner_kwargs) diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index 
8f2f8bba39ca..8c9c390911bd 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom input builders for edge-cases in different models.""" -from typing import Callable +from collections.abc import Callable from vllm.assets.image import ImageAsset from vllm.multimodal.image import rescale_image_size diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index e51d895772c0..0685a01da58f 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -7,7 +7,6 @@ import types from pathlib import PosixPath -from typing import Optional, Union import numpy as np import numpy.typing as npt @@ -26,7 +25,7 @@ from transformers.video_utils import VideoMetadata from vllm.logprobs import SampleLogprobs -from vllm.utils import is_list_of +from vllm.utils.collection_utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets from .types import RunnerOutput @@ -58,7 +57,7 @@ def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutpu def qwen_vllm_to_hf_output( vllm_output: RunnerOutput, model: str -) -> tuple[list[int], str, Optional[SampleLogprobs]]: +) -> tuple[list[int], str, SampleLogprobs | None]: """Sanitize vllm output [qwen models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -69,7 +68,7 @@ def qwen_vllm_to_hf_output( def qwen2_vllm_to_hf_output( vllm_output: RunnerOutput, model: str -) -> tuple[list[int], str, Optional[SampleLogprobs]]: +) -> tuple[list[int], str, SampleLogprobs | None]: """Sanitize vllm output [qwen2 models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -80,7 +79,7 @@ def qwen2_vllm_to_hf_output( def kimiv_vl_vllm_to_hf_output( vllm_output: RunnerOutput, model: str -) -> tuple[list[int], str, Optional[SampleLogprobs]]: +) -> tuple[list[int], str, SampleLogprobs | None]: """Sanitize vllm output [kimi_vl models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -99,7 +98,7 @@ def llava_image_vllm_to_hf_output( def llava_video_vllm_to_hf_output( vllm_output: RunnerOutput, model: str -) -> tuple[list[int], str, Optional[SampleLogprobs]]: +) -> tuple[list[int], str, SampleLogprobs | None]: config = AutoConfig.from_pretrained(model) mm_token_id = config.video_token_index return _llava_vllm_to_hf_output(vllm_output, model, mm_token_id) @@ -263,7 +262,7 @@ def get_llava_embeddings(image_assets: ImageTestAssets): ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( - tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset], ImageTestAssets] + tmp_path: PosixPath, prompt: str, assets: list[ImageAsset] | ImageTestAssets ) -> str: """Given a temporary dir path, export one or more image assets into the tempdir & replace its contents with the local path to the string so that @@ -440,7 +439,7 @@ def __init__(self, hf_runner: HfRunner): self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): + def __call__(self, text: str, images: Image | list[Image], **kwargs): from vllm.model_executor.models.h2ovl import ( IMG_CONTEXT, 
IMG_END, @@ -499,7 +498,7 @@ def __init__(self, hf_runner: HfRunner): self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): + def __call__(self, text: str, images: Image | list[Image], **kwargs): from vllm.model_executor.models.skyworkr1v import ( IMG_CONTEXT, IMG_END, @@ -560,8 +559,8 @@ def __init__(self, hf_runner: HfRunner): def __call__( self, text: str, - images: Union[Image, list[Image]] = None, - videos: Union[npt.NDArray, list[npt.NDArray]] = None, + images: Image | list[Image] = None, + videos: npt.NDArray | list[npt.NDArray] = None, **kwargs, ): from vllm.model_executor.models.internvl import ( @@ -650,7 +649,7 @@ def _internvl_generate( self, pixel_values: torch.FloatTensor, input_ids: torch.FloatTensor, - attention_mask: Optional[torch.LongTensor] = None, + attention_mask: torch.LongTensor | None = None, **generate_kwargs, ) -> torch.LongTensor: """Generate method for InternVL2 model without fixed use_cache.""" diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index bb34d1cc6dad..fe02f7188432 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -2,17 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Types for writing multimodal model tests.""" -from collections.abc import Iterable +from collections.abc import Callable, Iterable from enum import Enum from pathlib import PosixPath -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import Any, NamedTuple import torch from pytest import MarkDecorator from transformers import AutoModelForCausalLM from transformers.models.auto.auto_factory import _BaseAutoModelClass -from vllm.config import RunnerOption +from vllm.config.model import RunnerOption from vllm.logprobs import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -52,16 +52,16 @@ IMAGE_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] EMBEDDING_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0)] -RunnerOutput = tuple[list[int], str, Optional[SampleLogprobs]] +RunnerOutput = tuple[list[int], str, SampleLogprobs | None] class PromptWithMultiModalInput(NamedTuple): """Holds the multimodal input for a single test case.""" prompts: list[str] - image_data: Optional[PromptImageInput] = None - video_data: Optional[PromptVideoInput] = None - audio_data: Optional[PromptAudioInput] = None + image_data: PromptImageInput | None = None + video_data: PromptVideoInput | None = None + audio_data: PromptAudioInput | None = None class VLMTestType(Enum): @@ -87,17 +87,17 @@ class ImageSizeWrapper(NamedTuple): type: SizeType # A size factor is a wrapper of 0+ floats, # while a fixed size contains an iterable of integer pairs - data: Union[Iterable[float], Iterable[tuple[int, int]]] + data: Iterable[float] | Iterable[tuple[int, int]] class VLMTestInfo(NamedTuple): """Holds the configuration for 1+ tests for one model architecture.""" models: list[str] - test_type: Union[VLMTestType, Iterable[VLMTestType]] + test_type: VLMTestType | Iterable[VLMTestType] # Should be None only if this is a CUSTOM_INPUTS test - prompt_formatter: Optional[Callable[[str], str]] = None + prompt_formatter: Callable[[str], str] | None = None img_idx_to_prompt: Callable[[int], str] = lambda idx: "<image>\n" video_idx_to_prompt: Callable[[int], str] = lambda idx: 
"<video>\n" audio_idx_to_prompt: Callable[[int], str] = lambda idx: "<audio>\n" @@ -111,9 +111,9 @@ class VLMTestInfo(NamedTuple): # Function for converting ImageAssets to image embeddings; # We need to define this explicitly for embedding tests - convert_assets_to_embeddings: Optional[ - Callable[[ImageTestAssets], list[torch.Tensor]] - ] = None + convert_assets_to_embeddings: ( + Callable[[ImageTestAssets], list[torch.Tensor]] | None + ) = None # Exposed options for vLLM runner; we change these in a several tests, # but the defaults are derived from VllmRunner & the engine defaults @@ -123,25 +123,25 @@ class VLMTestInfo(NamedTuple): max_num_seqs: int = 256 runner: RunnerOption = "auto" tensor_parallel_size: int = 1 - vllm_runner_kwargs: Optional[dict[str, Any]] = None + vllm_runner_kwargs: dict[str, Any] | None = None # Optional callable which gets a list of token IDs from the model tokenizer - get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]] = None + get_stop_token_ids: Callable[[AnyTokenizer], list[int]] | None = None # Optional list of strings to stop generation, useful when stop tokens are # not special tokens in the tokenizer - stop_str: Optional[list[str]] = None + stop_str: list[str] | None = None # Exposed options for HF runner - hf_model_kwargs: Optional[dict[str, Any]] = None + hf_model_kwargs: dict[str, Any] | None = None # Indicates we should explicitly pass the EOS from the tokenizer use_tokenizer_eos: bool = False auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM - patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]] = None + patch_hf_runner: Callable[[HfRunner], HfRunner] | None = None # Post processors that if defined, will run oun the outputs of the # vLLM and HF runner, respectively (useful for sanitization, etc). 
- vllm_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]] = None - hf_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]] = None + vllm_output_post_proc: Callable[[RunnerOutput, str], Any] | None = None + hf_output_post_proc: Callable[[RunnerOutput, str], Any] | None = None # Consumes the output of the callables above and checks if they're equal comparator: Callable[..., None] = check_logprobs_close @@ -152,7 +152,7 @@ class VLMTestInfo(NamedTuple): max_tokens: int = 128 num_logprobs: int = 5 dtype: str = "auto" - distributed_executor_backend: Optional[str] = None + distributed_executor_backend: str | None = None # Only expanded in video tests num_video_frames: int = 16 @@ -162,19 +162,19 @@ class VLMTestInfo(NamedTuple): # once per tests (much like concatenating and wrapping in one parametrize # call) image_size_factors: Iterable[Iterable[float]] = IMAGE_SIZE_FACTORS - image_sizes: Optional[Iterable[Iterable[tuple[int, int]]]] = None + image_sizes: Iterable[Iterable[tuple[int, int]]] | None = None # Hack for updating a prompt to take into a local path; currently only used # for Qwen-VL, which requires encoding the image path / url into the prompt # for HF runner - prompt_path_encoder: Optional[ - Callable[[PosixPath, str, Union[list[ImageAsset], ImageTestAssets]], str] - ] = None # noqa: E501 + prompt_path_encoder: ( + Callable[[PosixPath, str, list[ImageAsset] | ImageTestAssets], str] | None + ) = None # noqa: E501 # Allows configuring a test to run with custom inputs - custom_test_opts: Optional[list[CustomTestOptions]] = None + custom_test_opts: list[CustomTestOptions] | None = None - marks: Optional[list[MarkDecorator]] = None + marks: list[MarkDecorator] | None = None def get_non_parametrized_runner_kwargs(self): """Returns a dictionary of expandable kwargs for items that are used @@ -207,10 +207,10 @@ class ExpandableVLMTestArgs(NamedTuple): max_tokens: int num_logprobs: int dtype: str - distributed_executor_backend: Optional[str] + distributed_executor_backend: str | None # Sizes are used for everything except for custom input tests - size_wrapper: Optional[ImageSizeWrapper] = None + size_wrapper: ImageSizeWrapper | None = None # Video only - num_video_frames: Optional[int] = None + num_video_frames: int | None = None # Custom inputs only - custom_test_opts: Optional[CustomTestOptions] = None + custom_test_opts: CustomTestOptions | None = None diff --git a/tests/models/multimodal/pooling/test_clip.py b/tests/models/multimodal/pooling/test_clip.py index b8c6c4abace9..95c678558f4f 100644 --- a/tests/models/multimodal/pooling/test_clip.py +++ b/tests/models/multimodal/pooling/test_clip.py @@ -45,14 +45,16 @@ def _run_test( all_outputs = [] for inputs in all_inputs: + inputs = hf_model.wrap_device(inputs) + if "pixel_values" in inputs: - inputs.pop("input_ids") pooled_output = hf_model.model.get_image_features( - **hf_model.wrap_device(inputs) + pixel_values=inputs.pixel_values, ).squeeze(0) else: pooled_output = hf_model.model.get_text_features( - **hf_model.wrap_device(inputs) + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, ).squeeze(0) all_outputs.append(pooled_output.tolist()) diff --git a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py index 7f30b1f299ba..ac3eb6e61723 100644 --- a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py +++ b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project -from typing import Callable +from collections.abc import Callable import pytest import torch diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index b474e851319a..5a97848216b8 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -7,7 +7,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor from vllm.distributed import cleanup_dist_env_and_memory -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets @@ -38,7 +38,7 @@ def run_intern_vit_test( config.norm_type = "rms_norm" hf_model = AutoModel.from_pretrained( - model, torch_dtype=torch_dtype, trust_remote_code=True + model, dtype=torch_dtype, trust_remote_code=True ).to("cuda") hf_outputs_per_image = [ hf_model(pixel_value.to("cuda")).last_hidden_state diff --git a/tests/models/multimodal/pooling/test_jinavl_reranker.py b/tests/models/multimodal/pooling/test_jinavl_reranker.py index 853f56618290..d7b33be7a0ad 100644 --- a/tests/models/multimodal/pooling/test_jinavl_reranker.py +++ b/tests/models/multimodal/pooling/test_jinavl_reranker.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Union import pytest from transformers import AutoModel @@ -32,7 +31,7 @@ def vllm_reranker( def create_image_param(url: str) -> ChatCompletionContentPartImageParam: return {"type": "image_url", "image_url": {"url": f"{url}"}} - query: Union[list[str], ScoreMultiModalParam] + query: list[str] | ScoreMultiModalParam if query_type == "text": query = query_strs elif query_type == "image": @@ -40,7 +39,7 @@ def create_image_param(url: str) -> ChatCompletionContentPartImageParam: content=[create_image_param(url) for url in query_strs] ) - documents: Union[list[str], ScoreMultiModalParam] + documents: list[str] | ScoreMultiModalParam if doc_type == "text": documents = document_strs elif doc_type == "image": diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index abf4150a9132..5082827962d8 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -34,12 +34,13 @@ def _run_test( dtype="half", enforce_eager=True, skip_tokenizer_init=True, + enable_mm_embeds=True, # Limit the maximum number of sequences to avoid the # test going OOM during the warmup run max_num_seqs=32, default_torch_num_threads=1, ) as vllm_model: - vllm_model.encode(prompt) + vllm_model.llm.encode(prompt, pooling_task="plugin") MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] diff --git a/tests/models/multimodal/pooling/test_radio.py b/tests/models/multimodal/pooling/test_radio.py index 80f594021ca8..8929563d8b05 100644 --- a/tests/models/multimodal/pooling/test_radio.py +++ b/tests/models/multimodal/pooling/test_radio.py @@ -9,7 +9,7 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.models.radio import RadioModel from vllm.transformers_utils.configs.radio import RadioConfig -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets @@ -45,7 +45,7 @@ def run_radio_test( hf_model = AutoModel.from_pretrained( model_id, config=config, - 
torch_dtype=torch_dtype, + dtype=torch_dtype, trust_remote_code=True, ).to("cuda") hf_model.eval() diff --git a/tests/models/multimodal/pooling/test_siglip.py b/tests/models/multimodal/pooling/test_siglip.py new file mode 100644 index 000000000000..f681b4787b69 --- /dev/null +++ b/tests/models/multimodal/pooling/test_siglip.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import SiglipModel + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ...utils import check_embeddings_close + +HF_TEXT_PROMPTS = [ + "a photo of a stop sign", + "a photo of a cherry blossom", +] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "", + "cherry_blossom": "", + } +) + +MODELS = ["google/siglip-base-patch16-224"] + + +def _run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + input_texts: list[str], + input_images: PromptImageInput, + model: str, + *, + dtype: str, +) -> None: + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=64 + ) as vllm_model: + vllm_outputs = vllm_model.embed(input_texts, images=input_images) + + with hf_runner(model, dtype=dtype, auto_cls=SiglipModel) as hf_model: + all_inputs = hf_model.get_inputs(input_texts, images=input_images) + + all_outputs = [] + for inputs in all_inputs: + inputs = hf_model.wrap_device(inputs) + + if "pixel_values" in inputs: + pooled_output = hf_model.model.get_image_features( + pixel_values=inputs.pixel_values, + ).squeeze(0) + else: + pooled_output = hf_model.model.get_text_features( + input_ids=inputs.input_ids, + ).squeeze(0) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text_image_no_crash( + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + texts = [HF_TEXT_PROMPTS[0]] + images = [image_assets[0].pil_image] + + with vllm_runner( + model, + runner="pooling", + dtype=dtype, + enforce_eager=True, + max_model_len=64, + ) as vllm_model: + with pytest.raises(ValueError, match="not both"): + vllm_model.embed(texts, images=images) + + vllm_model.embed(texts) + vllm_model.embed([""], images=images) diff --git 
a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index d9d85f7e0c00..313ab2fa8038 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Set as AbstractSet from functools import partial -from typing import Optional, Union import numpy as np import pytest -from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage +from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image @@ -22,14 +23,17 @@ from vllm.multimodal.inputs import MultiModalInputs from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.transformers_utils.tokenizer import ( - AnyTokenizer, MistralTokenizer, cached_tokenizer_from_config, encode_tokens, ) from ....multimodal.utils import random_audio, random_image, random_video -from ...registry import HF_EXAMPLE_MODELS +from ...registry import ( + _MULTIMODAL_EXAMPLE_MODELS, + _TRANSFORMERS_BACKEND_MODELS, + HF_EXAMPLE_MODELS, +) def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: @@ -83,6 +87,119 @@ def create_metadata(frames: np.ndarray): return mm_data +# For some multimodal models, the tokenizer will always add bos_token +# at the beginning of the prompt by default, causing hf_processor to output +# incorrect token ids. So we need to use `add_special_tokens=False` here +# to leave bos_token to be added by the processor. +_ADD_SPECIAL_TOKENS_OVERRIDES = { + "ovis": False, + "ovis2_5": False, + "paligemma": False, + "ultravox": False, + "whisper": False, +} + +_IGNORE_MM_KEYS = { + # In Ultravox, the audio_features can be different depending on padding + # The slight difference should not be a problem though, since + # attention_mask lets us ignore the difference. 
+ "ultravox": {"audio_features"}, +} + +MM_DATA_PATCHES = { + # GLM4.1V and Qwen3-VL requires video metadata to be included in the input + "glm4v": glm4_1v_patch_mm_data, + "glm4v_moe": glm4_1v_patch_mm_data, + "qwen3_vl": qwen3_vl_patch_mm_data, + "qwen3_vl_moe": qwen3_vl_patch_mm_data, +} + + +def _iter_model_ids_to_test(model_arch_list: AbstractSet[str]): + for model_arch in model_arch_list: + model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) + yield model_info.default + + for extra_type, extra_model_id in model_info.extras.items(): + if "fp" in extra_type: + continue # Redundant to test quantized models + + yield extra_model_id + + +def _get_model_ids_to_test(model_arch_list: AbstractSet[str]): + return list(_iter_model_ids_to_test(model_arch_list)) + + +def get_model_ids_to_test(): + transformers_arch_ids = { + model_id + for info in _TRANSFORMERS_BACKEND_MODELS.values() + for model_id in (info.default, *info.extras.values()) + } + vllm_only_archs = { + arch + for arch, info in _MULTIMODAL_EXAMPLE_MODELS.items() + if not any( + model_id in transformers_arch_ids + for model_id in (info.default, *info.extras.values()) + ) + } + + return _get_model_ids_to_test(vllm_only_archs) + + +def get_text_token_prompts( + processor: BaseMultiModalProcessor, + mm_data: MultiModalDataDict, +): + dummy_inputs = processor.dummy_inputs + tokenizer = processor.info.get_tokenizer() + model_config = processor.info.ctx.model_config + + model_type = model_config.hf_config.model_type + if model_type in MM_DATA_PATCHES: + mm_data = MM_DATA_PATCHES[model_type](mm_data) + + parsed_data = processor.data_parser.parse_mm_data(mm_data) + mm_counts = {k: len(vs) for k, vs in parsed_data.items()} + + text_prompt: str | None + token_prompt: list[int] + if isinstance(tokenizer, MistralTokenizer): + images = parsed_data.get("image", []) + request = ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text=""), + *(ImageChunk(image=image) for image in images), + ] + ), + ] + ) + res = tokenizer.mistral.encode_chat_completion(request) + + # Mistral does not support decode_tokens with skip_special_tokens=False + text_prompt = None + token_prompt = res.tokens + else: + inputs = dummy_inputs.get_dummy_processor_inputs( + model_config.max_model_len, + mm_counts, + ) + assert isinstance(inputs.prompt, str) + + text_prompt = inputs.prompt + token_prompt = encode_tokens( + tokenizer, + text_prompt, + add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type), + ) + + return text_prompt, token_prompt + + def _test_processing_correctness( model_id_or_arch: str, hit_rate: float, @@ -108,7 +225,9 @@ def _test_processing_correctness( hf_overrides=model_info.hf_overrides, # Ensure that the cache can fit all of the data mm_processor_cache_gb=2048, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, dtype=model_info.dtype, ) @@ -146,8 +265,6 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: baseline_processor = factories.build_processor(ctx, cache=None) cached_processor = factories.build_processor(ctx, cache=cache) - dummy_inputs = baseline_processor.dummy_inputs - tokenizer = baseline_processor.info.get_tokenizer() rng = np.random.RandomState(0) @@ -173,29 +290,6 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: for k, limit in 
limit_mm_per_prompt_ints.items() } - mm_counts = {k: len(vs) for k, vs in mm_data.items()} - - # Mistral chat outputs tokens directly, rather than text prompts - if isinstance(tokenizer, MistralTokenizer): - images = mm_data.get("image", []) - request = ChatCompletionRequest( - messages=[ - UserMessage( - content=[ - TextChunk(text=""), - *(ImageChunk(image=image) for image in images), - ] - ), - ] - ) - res = tokenizer.mistral.encode_chat_completion(request) - prompt = res.tokens - else: - prompt = dummy_inputs.get_dummy_processor_inputs( - model_config.max_model_len, - mm_counts, - ).prompt - # Drop unnecessary keys and test single -> multi conversion if rng.rand() < simplify_rate: for k in list(mm_data.keys()): @@ -206,8 +300,6 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: _test_processing_correctness_one( model_config, - tokenizer, - prompt, mm_data, baseline_processor, cached_processor, @@ -215,59 +307,17 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: ) -# For some multimodal models, tokenizer will always add bos_token -# at the beginning of prompt by default, causing hf_processor outputs -# incorrect token ids. So we need use `add_special_tokens=False` here -# to leave bos_token to be added by the processor. -_ADD_SPECIAL_TOKENS_OVERRIDES = { - "ovis": False, - "ovis2_5": False, - "paligemma": False, - "ultravox": False, - "whisper": False, -} - -_IGNORE_MM_KEYS = { - # In Ultravox, the audio_features can be different depending on padding - # The slight difference should not be a problem though, since - # attention_mask lets us ignore the difference. - "ultravox": {"audio_features"}, -} - -MM_DATA_PATCHES = { - # GLM4.1V and Qwen3-VL requires video metadata to be included in the input - "glm4v": glm4_1v_patch_mm_data, - "glm4v_moe": glm4_1v_patch_mm_data, - "qwen3_vl": qwen3_vl_patch_mm_data, - "qwen3_vl_moe": qwen3_vl_patch_mm_data, -} - - def _test_processing_correctness_one( model_config: ModelConfig, - tokenizer: AnyTokenizer, - prompt: Union[str, list[int]], mm_data: MultiModalDataDict, baseline_processor: BaseMultiModalProcessor, cached_processor: BaseMultiModalProcessor, batch_idx: int, ): model_type = model_config.hf_config.model_type - ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]()) - if model_type in MM_DATA_PATCHES: - mm_data = MM_DATA_PATCHES[model_type](mm_data) - if isinstance(prompt, str): - text_prompt = prompt - token_prompt = encode_tokens( - tokenizer, - prompt, - add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type), - ) - else: - # Mistral does not support decode_tokens with skip_special_tokens=False - text_prompt = None - token_prompt = prompt + text_prompt, token_prompt = get_text_token_prompts(baseline_processor, mm_data) + ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]()) baseline_tokenized_result = baseline_processor.apply( token_prompt, @@ -322,78 +372,7 @@ def _test_processing_correctness_one( ) -@pytest.mark.parametrize( - "model_id", - [ - "rhymes-ai/Aria", - "CohereForAI/aya-vision-8b", - "Salesforce/blip2-opt-2.7b", - "facebook/chameleon-7b", - "CohereLabs/command-a-vision-07-2025", - "deepseek-ai/deepseek-vl2-tiny", - "baidu/ERNIE-4.5-VL-28B-A3B-PT", - "adept/fuyu-8b", - "google/gemma-3-4b-it", - "google/gemma-3n-E2B-it", - "zai-org/glm-4v-9b", - "zai-org/GLM-4.1V-9B-Thinking", - "zai-org/GLM-4.5V", - "ibm-granite/granite-speech-3.3-2b", - "h2oai/h2ovl-mississippi-800m", - "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", - 
"HuggingFaceM4/Idefics3-8B-Llama3", - "internlm/Intern-S1", - "OpenGVLab/InternVL2-1B", - "OpenGVLab/InternVL3-1B", - "OpenGVLab/InternVL3_5-1B", - "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", - "OpenGVLab/InternVL3_5-30B-A3B", - "Kwai-Keye/Keye-VL-8B-Preview", - "Kwai-Keye/Keye-VL-1_5-8B", - "moonshotai/Kimi-VL-A3B-Instruct", - "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "llava-hf/llava-1.5-7b-hf", - "llava-hf/llava-v1.6-mistral-7b-hf", - "llava-hf/LLaVA-NeXT-Video-7B-hf", - "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", - "TIGER-Lab/Mantis-8B-siglip-llama3", - "mispeech/midashenglm-7b", - "openbmb/MiniCPM-Llama3-V-2_5", - "openbmb/MiniCPM-o-2_6", - "openbmb/MiniCPM-V-2_6", - "MiniMaxAI/MiniMax-VL-01", - "allenai/Molmo-7B-D-0924", - "allenai/Molmo-7B-O-0924", - "nvidia/NVLM-D-72B", - "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", - "AIDC-AI/Ovis1.6-Gemma2-9B", - "AIDC-AI/Ovis1.6-Llama3.2-3B", - "AIDC-AI/Ovis2-1B", - "AIDC-AI/Ovis2.5-2B", - "google/paligemma-3b-mix-224", - "google/paligemma2-3b-ft-docci-448", - "microsoft/Phi-3.5-vision-instruct", - "microsoft/Phi-4-multimodal-instruct", - "mistralai/Pixtral-12B-2409", - "mistral-community/pixtral-12b", - "Qwen/Qwen-VL-Chat", - "Qwen/Qwen2-VL-2B-Instruct", - "Qwen/Qwen2.5-VL-3B-Instruct", - "Qwen/Qwen2-Audio-7B-Instruct", - "Qwen/Qwen2.5-Omni-3B", - "Qwen/Qwen3-VL-4B-Instruct", - "Qwen/Qwen3-VL-30B-A3B-Instruct", - "YannQi/R-4B", - "Skywork/Skywork-R1V-38B", - "HuggingFaceTB/SmolVLM2-2.2B-Instruct", - "stepfun-ai/step3", - "fixie-ai/ultravox-v0_5-llama-3_2-1b", - "openai/whisper-large-v3", - "omni-research/Tarsier-7b", - "omni-research/Tarsier2-Recap-7b", - "mistralai/Voxtral-Mini-3B-2507", - ], -) +@pytest.mark.parametrize("model_id", get_model_ids_to_test()) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("simplify_rate", [1.0]) @@ -404,7 +383,12 @@ def test_processing_correctness( simplify_rate: float, ): if model_id == "google/gemma-3n-E2B-it": - pytest.skip("Skipping gemma-3n-E2B-it due to transformers #39911 bug.") + pytest.skip("Fix later") + if model_id == "OpenGVLab/InternVL2-2B": + pytest.skip("Fix later") + if model_id == "jinaai/jina-reranker-m0": + pytest.skip("Fix later") + _test_processing_correctness( model_id, hit_rate=hit_rate, @@ -439,7 +423,7 @@ def _assert_inputs_equal( a: MultiModalInputs, b: MultiModalInputs, *, - ignore_mm_keys: Optional[set[str]] = None, + ignore_mm_keys: set[str] | None = None, msg: str = "", ): if ignore_mm_keys is None: diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index bd21d4008fa7..1701d9dd8f01 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -3,7 +3,6 @@ """Tests for H2OVL's multimodal preprocessing kwargs.""" from collections.abc import Mapping -from typing import Optional import pytest from PIL import Image @@ -149,7 +148,7 @@ def test_processor_override( size_factors: list[int], min_dynamic_patch: int, max_dynamic_patch: int, - dynamic_image_size: Optional[bool], + dynamic_image_size: bool | None, kwargs_on_init: bool, ): mm_processor_kwargs = { diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index 6f6529cb9401..b4994295d3a8 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -3,7 +3,6 @@ """Tests for InternVL's multimodal 
preprocessing kwargs.""" from collections.abc import Mapping -from typing import Optional import pytest from PIL import Image @@ -103,7 +102,7 @@ def test_processor_override( size_factors: list[int], min_dynamic_patch: int, max_dynamic_patch: int, - dynamic_image_size: Optional[bool], + dynamic_image_size: bool | None, kwargs_on_init: bool, ): mm_processor_kwargs = { diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py index 6ff6f396fa33..5311ab1b78c6 100644 --- a/tests/models/multimodal/processing/test_nemotron_vl.py +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -3,7 +3,6 @@ """Tests for Nemotron-Nano-VL's multimodal preprocessing kwargs.""" from collections.abc import Mapping -from typing import Optional import pytest from PIL import Image @@ -105,7 +104,7 @@ def test_processor_override( size_factors: list[int], min_dynamic_patch: int, max_dynamic_patch: int, - dynamic_image_size: Optional[bool], + dynamic_image_size: bool | None, kwargs_on_init: bool, ): mm_processor_kwargs = { diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 2c4d109c3687..687d1ef349f8 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -4,13 +4,11 @@ from collections.abc import Iterable from contextlib import contextmanager from functools import partial -from typing import Any, Union +from typing import Any, TypeAlias import numpy as np import pytest import torch.nn as nn -from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage -from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config @@ -25,7 +23,6 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.interfaces import ( SupportsMultiModal, supports_multimodal, @@ -34,35 +31,35 @@ from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from vllm.utils import is_list_of +from vllm.utils.collection_utils import is_list_of +from vllm.utils.torch_utils import set_default_torch_dtype -from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS +from ...registry import HF_EXAMPLE_MODELS from ...utils import dummy_hf_overrides - -ARCH_TO_SKIP = { - "MolmoForCausalLM": "incompatible requirements", -} -ARCH_NEEDS_EXTRAS = [ - "InternVLChatModel", - "Idefics3ForConditionalGeneration", - "LlavaForConditionalGeneration", - "MiniCPMV", - "PaliGemmaForConditionalGeneration", -] -REPO_ID_TO_SKIP = { - "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test", -} +from .test_common import get_model_ids_to_test, get_text_token_prompts ImageInput = list[Image.Image] -VideoInput = Union[ - list[Image.Image], list[np.ndarray], list[tuple[np.ndarray, dict[str, Any]]] -] +VideoInput: TypeAlias = ( + list[Image.Image] | list[np.ndarray] | list[tuple[np.ndarray, dict[str, Any]]] +) AudioInput = list[tuple[np.ndarray, int]] +MM_OPTIONS_OVERRIDES = { + # Qwen3-VL's default profiling video size (64x64) can cause trouble + # after resizing, so we override it here for testing. 
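+ # These overrides are keyed by hf_config.model_type and are passed as + # mm_options when building the dummy processor inputs below.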
+ "qwen3_vl": dict( + video=VideoDummyOptions(num_frames=128, width=256, height=256), + ), + "qwen3_vl_moe": dict( + video=VideoDummyOptions(num_frames=128, width=256, height=256), + ), +} + + def _resize_data( - _data: Union[Image.Image, np.ndarray], size_factor: float -) -> Union[Image.Image, np.ndarray]: + _data: Image.Image | np.ndarray, size_factor: float +) -> Image.Image | np.ndarray: assert size_factor <= 1, "Size factor must be less than 1" # Image input if isinstance(_data, Image.Image): @@ -87,13 +84,13 @@ def _resize_data( def resize_mm_data( - data: Union[ImageInput, VideoInput, AudioInput], size_factors: tuple[float, ...] -) -> Union[ImageInput, VideoInput, AudioInput]: + data: ImageInput | VideoInput | AudioInput, size_factors: tuple[float, ...] +) -> ImageInput | VideoInput | AudioInput: size_factors = size_factors[: len(data)] if is_list_of(data, (Image.Image, np.ndarray, list)): return [_resize_data(d, s) for d, s in zip(data, size_factors)] elif is_list_of(data, tuple): - return [(_resize_data(d, s), meta) for (d, meta), s in zip(data, size_factors)] + return [_resize_data(d, s) for (d, _), s in zip(data, size_factors)] raise ValueError("Unsupported multimodal data type.") @@ -103,6 +100,8 @@ def create_batched_mm_kwargs( processor: BaseMultiModalProcessor, size_factors: tuple[float, ...] = (1.0, 0.5, 0.25), ) -> Iterable[tuple[str, int, BatchedTensorInputs]]: + model_type = model_config.hf_config.model_type + processing_info = processor.info dummy_inputs = processor.dummy_inputs supported_mm_limits = processing_info.get_supported_mm_limits() @@ -113,32 +112,19 @@ def create_batched_mm_kwargs( processor_inputs = dummy_inputs.get_dummy_processor_inputs( seq_len=model_config.max_model_len, mm_counts=mm_counts, + mm_options=MM_OPTIONS_OVERRIDES.get(model_type), ) mm_data = processor_inputs.mm_data resized_mm_data = { modality: resize_mm_data(data, size_factors) for modality, data in mm_data.items() } - # Mistral chat outputs tokens directly, rather than text prompts - if model_config.tokenizer_mode == "mistral": - images = resized_mm_data.get("image", []) - request = ChatCompletionRequest( - messages=[ - UserMessage( - content=[ - TextChunk(text=""), - *(ImageChunk(image=image) for image in images), - ] - ), - ] - ) - tokenizer = processing_info.get_tokenizer() - res = tokenizer.mistral.encode_chat_completion(request) - prompt = res.tokens - else: - prompt = processor_inputs.prompt + + # video metadata will be added back to the resized video data here. 
+ text_prompt, token_prompt = get_text_token_prompts(processor, resized_mm_data) + mm_kwargs = processor.apply( - prompt=prompt, + prompt=token_prompt if text_prompt is None else text_prompt, mm_data=resized_mm_data, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, tokenization_kwargs=processor_inputs.tokenization_kwargs, @@ -174,35 +160,15 @@ def initialize_dummy_model( cleanup_dist_env_and_memory() -def get_model_id_to_test(model_arch_list: Iterable[str]) -> list[tuple[str, str]]: - filtered_results = [] - for model_arch in model_arch_list: - model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) - if model_info.extras and model_arch in ARCH_NEEDS_EXTRAS: - available_repos = list( - map( - lambda model_id: (model_arch, model_id), - [model_info.default, *model_info.extras.values()], - ) - ) - filtered_results.extend(available_repos) - else: - filtered_results.append((model_arch, model_info.default)) - return filtered_results - - -@pytest.mark.parametrize( - "model_arch, model_id", get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()) -) -def test_model_tensor_schema(model_arch: str, model_id: str): - if model_arch in ARCH_TO_SKIP: - pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") - if model_id in REPO_ID_TO_SKIP: - pytest.skip(f"Skipping {model_id} due to {REPO_ID_TO_SKIP[model_id]}") - - model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) +@pytest.mark.parametrize("model_id", get_model_ids_to_test()) +def test_model_tensor_schema(model_id: str): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_available_online(on_fail="skip") - model_info.check_transformers_version(on_fail="skip", check_max_version=False) + model_info.check_transformers_version(on_fail="skip") + + model_arch = next( + arch for arch, info in HF_EXAMPLE_MODELS.hf_models.items() if info == model_info + ) hf_overrides_fn = partial( dummy_hf_overrides, @@ -217,7 +183,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=hf_overrides_fn, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, dtype=model_info.dtype, ) diff --git a/tests/models/multimodal/test_mapping.py b/tests/models/multimodal/test_mapping.py index 2179cf33a573..2f38dc450ef9 100644 --- a/tests/models/multimodal/test_mapping.py +++ b/tests/models/multimodal/test_mapping.py @@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str): revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, dtype=model_info.dtype, ) diff --git a/tests/models/quantization/test_awq.py b/tests/models/quantization/test_awq.py index c4c10832ede3..70464cf7fb41 100644 --- a/tests/models/quantization/test_awq.py +++ b/tests/models/quantization/test_awq.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import pytest import torch @@ -30,7 +29,7 @@ def run_awq_test( max_tokens: int, num_logprobs: int, 
tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + distributed_executor_backend: str | None = None, ): images = [asset.pil_image for asset in image_assets] diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index 55b149ae5da7..2a6f34a9c482 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -9,9 +9,9 @@ import pytest from tests.quantization.utils import is_quant_method_supported +from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 from vllm.platforms import current_platform from vllm.utils import STR_BACKEND_ENV_VAR - from ..utils import check_logprobs_close @@ -69,8 +69,10 @@ def test_models( if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") - if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None): - pytest.skip(f"{kv_cache_dtype} is not supported on this platform.") + if not flash_attn_supports_fp8(): + pytest.skip( + f"{kv_cache_dtype} is not supported on this GPU type with {backend} attention." + ) with monkeypatch.context() as m: m.setenv("TOKENIZERS_PARALLELISM", "true") diff --git a/tests/models/registry.py b/tests/models/registry.py index e7affb41565c..8e11ee755bf7 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -3,14 +3,13 @@ from collections.abc import Mapping, Set from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Literal import pytest -import torch from packaging.version import Version from transformers import __version__ as TRANSFORMERS_VERSION -from vllm.config import ModelDType, TokenizerMode +from vllm.config.model import ModelDType, TokenizerMode @dataclass(frozen=True) @@ -21,36 +20,42 @@ class _HfExamplesInfo: extras: Mapping[str, str] = field(default_factory=dict) """Extra models to use for testing this architecture.""" - tokenizer: Optional[str] = None + tokenizer: str | None = None """Set the tokenizer to load for this architecture.""" tokenizer_mode: TokenizerMode = "auto" """Set the tokenizer type for this architecture.""" - speculative_model: Optional[str] = None + speculative_model: str | None = None """ The default model to use for testing this architecture, which is only used for speculative decoding. """ - min_transformers_version: Optional[str] = None + speculative_method: str | None = None + """ + The method to use for speculative decoding. + """ + + min_transformers_version: str | None = None """ The minimum version of HF Transformers that is required to run this model. """ - max_transformers_version: Optional[str] = None + max_transformers_version: str | None = None """ The maximum version of HF Transformers that this model runs on. """ - transformers_version_reason: Optional[str] = None + transformers_version_reason: str | None = None """ The reason for the minimum/maximum version requirement. """ - skip_tokenizer_init: bool = False + require_embed_inputs: bool = False """ - If true, skip initialization of tokenizer and detokenizer. + If `True`, enables prompt and multi-modal embedding inputs while + disabling tokenization. """ dtype: ModelDType = "auto" @@ -67,34 +72,31 @@ class _HfExamplesInfo: is_available_online: bool = True """ - Set this to ``False`` if the name of this architecture no longer exists on + Set this to `False` if the name of this architecture no longer exists on the HF repo. 
To maintain backwards compatibility, we have not removed them from the main model registry, so without this flag the registry tests will fail. """ trust_remote_code: bool = False - """The ``trust_remote_code`` level required to load the model.""" - - v0_only: bool = False - """The model is only available with the vLLM V0 engine.""" + """The `trust_remote_code` level required to load the model.""" hf_overrides: dict[str, Any] = field(default_factory=dict) - """The ``hf_overrides`` required to load the model.""" + """The `hf_overrides` required to load the model.""" - max_model_len: Optional[int] = None + max_model_len: int | None = None """ The maximum model length to use for this model. Some models default to a length that is too large to fit into memory in CI. """ - revision: Optional[str] = None + revision: str | None = None """ The specific revision (commit hash, tag, or branch) to use for the model. If not specified, the default revision will be used. """ - max_num_seqs: Optional[int] = None + max_num_seqs: int | None = None """Maximum number of sequences to be processed in a single iteration.""" use_original_num_layers: bool = False @@ -109,7 +111,7 @@ def check_transformers_version( on_fail: Literal["error", "skip", "return"], check_min_version: bool = True, check_max_version: bool = True, - ) -> Optional[str]: + ) -> str | None: """ If the installed transformers version does not meet the requirements, perform the given action. @@ -171,11 +173,7 @@ def check_available_online( _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] - "ApertusForCausalLM": _HfExamplesInfo( - "swiss-ai/Apertus-8B-2509", - min_transformers_version="4.56.0", - trust_remote_code=True, - ), + "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-Instruct-2509"), "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True), "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"), @@ -196,7 +194,6 @@ def check_available_online( ), "BambaForCausalLM": _HfExamplesInfo( "ibm-ai-platform/Bamba-9B-v1", - min_transformers_version="4.55.3", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}, ), "BloomForCausalLM": _HfExamplesInfo( @@ -216,11 +213,7 @@ def check_available_online( "CohereForAI/c4ai-command-r7b-12-2024", trust_remote_code=True, ), - "CwmForCausalLM": _HfExamplesInfo( - "facebook/cwm", - trust_remote_code=True, - is_available_online=False, - ), + "CwmForCausalLM": _HfExamplesInfo("facebook/cwm", min_transformers_version="4.58"), "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"), "DeciLMForCausalLM": _HfExamplesInfo( "nvidia/Llama-3_3-Nemotron-Super-49B-v1", @@ -236,38 +229,30 @@ def check_available_online( trust_remote_code=True, ), "DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"), - "Ernie4_5ForCausalLM": _HfExamplesInfo( - "baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54" - ), - "Ernie4_5_MoeForCausalLM": _HfExamplesInfo( - "baidu/ERNIE-4.5-21B-A3B-PT", min_transformers_version="4.54" - ), + "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT"), + "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT"), "ExaoneForCausalLM": _HfExamplesInfo( "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True ), - "Exaone4ForCausalLM": _HfExamplesInfo( - "LGAI-EXAONE/EXAONE-4.0-32B", min_transformers_version="4.54" - ), + "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"), "Fairseq2LlamaForCausalLM": 
_HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), + "FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), - "Gemma3nForCausalLM": _HfExamplesInfo( - "google/gemma-3n-E2B-it", min_transformers_version="4.53" - ), + "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"), "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"), "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"), - "Glm4MoeForCausalLM": _HfExamplesInfo( - "zai-org/GLM-4.5", min_transformers_version="4.54" - ), + "Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5"), "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}), "GPTBigCodeForCausalLM": _HfExamplesInfo( "bigcode/starcoder", - extras={"tiny": "bigcode/tiny_starcoder_py"}, - min_transformers_version="4.55.1", - transformers_version_reason="HF model broken in 4.55.0", + extras={ + "tiny": "bigcode/tiny_starcoder_py", + "santacoder": "bigcode/gpt_bigcode-santacoder", + }, ), "GPTJForCausalLM": _HfExamplesInfo( "Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"} @@ -279,8 +264,7 @@ def check_available_online( "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), "GraniteMoeHybridForCausalLM": _HfExamplesInfo( - "ibm-granite/granite-4.0-tiny-preview", - min_transformers_version="4.55.3", + "ibm-granite/granite-4.0-tiny-preview" ), "GraniteMoeSharedForCausalLM": _HfExamplesInfo( "ibm-research/moe-7b-1b-active-shared-experts" @@ -288,15 +272,10 @@ def check_available_online( "Grok1ModelForCausalLM": _HfExamplesInfo( "hpcai-tech/grok-1", trust_remote_code=True ), + "HunYuanDenseV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-7B-Instruct"), "HunYuanMoEV1ForCausalLM": _HfExamplesInfo( "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True ), - # TODO: Remove is_available_online once their config.json is fixed - "HunYuanDenseV1ForCausalLM": _HfExamplesInfo( - "tencent/Hunyuan-7B-Instruct-0124", - trust_remote_code=True, - is_available_online=False, - ), "InternLMForCausalLM": _HfExamplesInfo( "internlm/internlm-chat-7b", trust_remote_code=True ), @@ -312,15 +291,12 @@ def check_available_online( "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo( "ai21labs/AI21-Jamba-1.5-Mini", - min_transformers_version="4.55.3", extras={ "tiny": "ai21labs/Jamba-tiny-dev", "random": "ai21labs/Jamba-tiny-random", }, ), - "Lfm2ForCausalLM": _HfExamplesInfo( - "LiquidAI/LFM2-1.2B", min_transformers_version="4.54" - ), + "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B"), "Lfm2MoeForCausalLM": _HfExamplesInfo( "LiquidAI/LFM2-8B-A1B", min_transformers_version="4.58" ), @@ -330,6 +306,7 @@ def check_available_online( "guard": "meta-llama/Llama-Guard-3-1B", "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", + "tiny": "hmellor/tiny-random-LlamaForCausalLM", }, ), "LLaMAForCausalLM": _HfExamplesInfo( @@ -337,7 +314,6 @@ def check_available_online( ), "Llama4ForCausalLM": _HfExamplesInfo( "meta-llama/Llama-4-Scout-17B-16E-Instruct", - is_available_online=False, ), "LongcatFlashForCausalLM": _HfExamplesInfo( "meituan-longcat/LongCat-Flash-Chat", 
trust_remote_code=True @@ -345,7 +321,6 @@ def check_available_online( "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), "Mamba2ForCausalLM": _HfExamplesInfo( "mistralai/Mamba-Codestral-7B-v0.1", - min_transformers_version="4.55.3", extras={ "random": "yujiepan/mamba2-codestral-v0.1-tiny-random", }, @@ -420,7 +395,6 @@ def check_available_online( "SeedOssForCausalLM": _HfExamplesInfo( "ByteDance-Seed/Seed-OSS-36B-Instruct", trust_remote_code=True, - is_available_online=False, ), "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), @@ -486,6 +460,10 @@ def check_available_online( "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), + "BertSpladeSparseEmbeddingModel": _HfExamplesInfo( + "naver/splade-v3", + hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]}, + ), # [Multimodal] "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), @@ -493,20 +471,20 @@ def check_available_online( "TIGER-Lab/VLM2Vec-Full", trust_remote_code=True ), "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), + "SiglipModel": _HfExamplesInfo("google/siglip-base-patch16-224"), "PrithviGeoSpatialMAE": _HfExamplesInfo( "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", - dtype=torch.float16, + dtype="float16", enforce_eager=True, - skip_tokenizer_init=True, - # This is to avoid the model - # going OOM in CI + require_embed_inputs=True, + # This is to avoid the model going OOM in CI max_num_seqs=32, ), "Terratorch": _HfExamplesInfo( "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", - dtype=torch.float16, + dtype="float16", enforce_eager=True, - skip_tokenizer_init=True, + require_embed_inputs=True, # This is to avoid the model going OOM in CI max_num_seqs=32, ), @@ -562,6 +540,10 @@ def check_available_online( # [Decoder-only] "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), + "BeeForConditionalGeneration": _HfExamplesInfo( + "Open-Bee/Bee-8B-RL", + trust_remote_code=True, + ), "Blip2ForConditionalGeneration": _HfExamplesInfo( "Salesforce/blip2-opt-2.7b", extras={"6b": "Salesforce/blip2-opt-6.7b"}, @@ -577,6 +559,9 @@ def check_available_online( transformers_version_reason="HF model is not compatible.", hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}, ), + "DeepseekOCRForCausalLM": _HfExamplesInfo( + "deepseek-ai/DeepSeek-OCR", + ), "DotsOCRForCausalLM": _HfExamplesInfo( "rednote-hilab/dots.ocr", trust_remote_code=True ), @@ -587,10 +572,7 @@ def check_available_online( ), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), - "Gemma3nForConditionalGeneration": _HfExamplesInfo( - "google/gemma-3n-E2B-it", - min_transformers_version="4.53", - ), + "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"), "GraniteSpeechForConditionalGeneration": _HfExamplesInfo( "ibm-granite/granite-speech-3.3-2b" ), @@ -600,9 +582,7 @@ def check_available_online( hf_overrides={"architectures": ["GLM4VForCausalLM"]}, ), "Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), - 
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo( - "zai-org/GLM-4.5V", min_transformers_version="4.56" - ), + "Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V"), "H2OVLChatModel": _HfExamplesInfo( "h2oai/h2ovl-mississippi-800m", trust_remote_code=True, @@ -616,9 +596,7 @@ def check_available_online( ), "Idefics3ForConditionalGeneration": _HfExamplesInfo( "HuggingFaceM4/Idefics3-8B-Llama3", - {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, - min_transformers_version="4.56", - transformers_version_reason="HF model broken in 4.55", + extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, ), "InternS1ForConditionalGeneration": _HfExamplesInfo( "internlm/Intern-S1", trust_remote_code=True @@ -648,6 +626,10 @@ def check_available_online( extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, trust_remote_code=True, ), + "LightOnOCRForConditionalGeneration": _HfExamplesInfo( + "lightonai/LightOnOCR-1B", + is_available_online=False, + ), "Llama4ForConditionalGeneration": _HfExamplesInfo( "meta-llama/Llama-4-Scout-17B-16E-Instruct", max_model_len=10240, @@ -691,7 +673,6 @@ def check_available_online( "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo( "MiniMaxAI/MiniMax-VL-01", trust_remote_code=True, - v0_only=True, ), "Mistral3ForConditionalGeneration": _HfExamplesInfo( "mistralai/Mistral-Small-3.1-24B-Instruct-2503", @@ -749,6 +730,8 @@ def check_available_online( "Qwen/Qwen-VL", extras={"chat": "Qwen/Qwen-VL-Chat"}, trust_remote_code=True, + max_transformers_version="4.53.3", + transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, ), "Qwen2AudioForConditionalGeneration": _HfExamplesInfo( @@ -765,22 +748,23 @@ def check_available_online( "Qwen/Qwen3-VL-4B-Instruct", max_model_len=4096, min_transformers_version="4.57", - is_available_online=False, ), "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo( "Qwen/Qwen3-VL-30B-A3B-Instruct", max_model_len=4096, min_transformers_version="4.57", - is_available_online=False, + ), + "Qwen3OmniMoeForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen3-Omni-30B-A3B-Instruct", + max_model_len=4096, + min_transformers_version="4.57", ), "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True), "SkyworkR1VChatModel": _HfExamplesInfo( "Skywork/Skywork-R1V-38B", trust_remote_code=True ), "SmolVLMForConditionalGeneration": _HfExamplesInfo( - "HuggingFaceTB/SmolVLM2-2.2B-Instruct", - min_transformers_version="4.56", - transformers_version_reason="HF model broken in 4.55", + "HuggingFaceTB/SmolVLM2-2.2B-Instruct" ), "Step3VLForConditionalGeneration": _HfExamplesInfo( "stepfun-ai/step3", trust_remote_code=True @@ -796,7 +780,6 @@ def check_available_online( ), "VoxtralForConditionalGeneration": _HfExamplesInfo( "mistralai/Voxtral-Mini-3B-2507", - min_transformers_version="4.54", # disable this temporarily until we support HF format is_available_online=False, ), @@ -857,8 +840,8 @@ def check_available_online( "EagleMiniCPMForCausalLM": _HfExamplesInfo( "openbmb/MiniCPM-1B-sft-bf16", trust_remote_code=True, - is_available_online=False, speculative_model="openbmb/MiniCPM-2B-sft-bf16", + speculative_method="eagle", tokenizer="openbmb/MiniCPM-2B-sft-bf16", ), "ErnieMTPModel": _HfExamplesInfo( @@ -869,8 +852,6 @@ def check_available_online( "Glm4MoeMTPModel": _HfExamplesInfo( "zai-org/GLM-4.5", speculative_model="zai-org/GLM-4.5", - min_transformers_version="4.56", - is_available_online=False, ), 
"LongCatFlashMTPModel": _HfExamplesInfo( "meituan-longcat/LongCat-Flash-Chat", @@ -902,11 +883,11 @@ def check_available_online( "TransformersForCausalLM": _HfExamplesInfo( "hmellor/Ilama-3.2-1B", trust_remote_code=True ), - "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "TransformersMultiModalForCausalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "TransformersMoEForCausalLM": _HfExamplesInfo( "allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0" ), - "TransformersMoEForMultimodalLM": _HfExamplesInfo( + "TransformersMultiModalMoEForCausalLM": _HfExamplesInfo( "Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0" ), "TransformersMoEEmbeddingModel": _HfExamplesInfo( @@ -915,6 +896,10 @@ def check_available_online( "TransformersMoEForSequenceClassification": _HfExamplesInfo( "Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" ), + "TransformersMultiModalEmbeddingModel": _HfExamplesInfo("google/gemma-3-4b-it"), + "TransformersMultiModalForSequenceClassification": _HfExamplesInfo( + "google/gemma-3-4b-it" + ), } _EXAMPLE_MODELS = { diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index f501798ffa36..48a6f34366cf 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -7,7 +7,7 @@ import pytest from vllm import LLM -from vllm.utils import GiB_bytes +from vllm.utils.mem_constants import GiB_bytes from vllm.v1.core.kv_cache_utils import ( generate_scheduler_kv_cache_config, get_kv_cache_configs, @@ -37,7 +37,7 @@ "JinaVLForRanking", "InternVLChatModel", "InternLM2ForRewardModel", - "TransformersForMultimodalLM", + "TransformersMultiModalForCausalLM", "PrithviGeoSpatialMAE", "UltravoxModel", "DeepSeekMTPModel", @@ -88,13 +88,15 @@ def _initialize_kv_caches_v1(self, vllm_config): # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config return 1, 0, scheduler_kv_cache_config + if model_arch == "MiniMaxVL01ForConditionalGeneration": + pytest.skip( + "pickle error when loading `transformers.models.auto.CONFIG_MAPPING`" + ) + with ( patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), monkeypatch.context() as m, ): - if model_info.v0_only: - # NOTE(woosuk): skip the test for V0-only models - return if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # has cc==8.9 which hasn't supported FA3 yet. 
Remove this hack when @@ -102,16 +104,20 @@ def _initialize_kv_caches_v1(self, vllm_config): m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") if model_arch == "WhisperForConditionalGeneration": m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + LLM( model_info.default, tokenizer=model_info.tokenizer, tokenizer_mode=model_info.tokenizer_mode, revision=model_info.revision, enforce_eager=model_info.enforce_eager, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, dtype=model_info.dtype, speculative_config={ "model": model_info.speculative_model, + "method": model_info.speculative_method, "num_speculative_tokens": 1, } if model_info.speculative_model @@ -132,8 +138,6 @@ def _initialize_kv_caches_v1(self, vllm_config): @pytest.mark.parametrize("model_arch", MINIMAL_MODEL_ARCH_LIST) def test_can_initialize_small_subset(model_arch: str, monkeypatch: pytest.MonkeyPatch): """Test initializing small subset of supported models""" - if model_arch == "Lfm2ForCausalLM": - pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) @@ -144,8 +148,6 @@ def test_can_initialize_large_subset(model_arch: str, monkeypatch: pytest.Monkey This test covers the complement of the tests covered in the "small subset" test. """ - if model_arch == "Lfm2ForCausalLM": - pytest.skip("Skipping until test supports V1-only models") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py index cadce5d2b2bb..15764145bc1a 100644 --- a/tests/models/test_terratorch.py +++ b/tests/models/test_terratorch.py @@ -32,6 +32,7 @@ def test_inference( dtype="half", enforce_eager=True, skip_tokenizer_init=True, + enable_mm_embeds=True, # Limit the maximum number of sequences to avoid the # test going OOM during the warmup run max_num_seqs=32, diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index b434c0955be7..d8a1aace8332 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Test the functionality of the Transformers backend.""" -from typing import Any, Optional, Union +from typing import Any import pytest @@ -21,12 +21,12 @@ def get_model(arch: str) -> str: def check_implementation( - runner_ref: type[Union[HfRunner, VllmRunner]], + runner_ref: type[HfRunner | VllmRunner], runner_test: type[VllmRunner], example_prompts: list[str], model: str, - kwargs_ref: Optional[dict[str, Any]] = None, - kwargs_test: Optional[dict[str, Any]] = None, + kwargs_ref: dict[str, Any] | None = None, + kwargs_test: dict[str, Any] | None = None, **kwargs, ): if kwargs_ref is None: @@ -211,11 +211,7 @@ def test_embed_loading(vllm_runner, model): def test_pooling(hf_runner, vllm_runner, example_prompts, arch): model = get_model(arch) - vllm_kwargs = dict( - max_model_len=None, - model_impl="transformers", - compilation_config=dict(cudagraph_capture_sizes=[8]), - ) + vllm_kwargs = dict(max_model_len=None, model_impl="transformers") hf_kwargs = dict() if arch == "TransformersEmbeddingModel": diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index b323bca79f4e..82ba958a58c4 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -19,7 +19,8 @@ 
run_dp_sharded_vision_model, ) from vllm.platforms import current_platform -from vllm.utils import get_open_port, update_environment_variables +from vllm.utils.network_utils import get_open_port +from vllm.utils.system_utils import update_environment_variables pytestmark = pytest.mark.cpu_test diff --git a/tests/models/utils.py b/tests/models/utils.py index c20e50ff1bff..9843887a1320 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -4,17 +4,18 @@ import warnings from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any import torch import torch.nn.functional as F from transformers import PretrainedConfig -from vllm.config import ModelConfig, ModelDType, RunnerOption +from vllm.config.model import ModelConfig, ModelDType, RunnerOption from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.multimodal.processing import InputProcessingContext from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from .. import ci_envs from .registry import HF_EXAMPLE_MODELS TokensText = tuple[list[int], str] @@ -57,7 +58,7 @@ def check_outputs_equal( # # Assumes prompt logprobs were not requested. TokensTextLogprobs = tuple[ - list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]] + list[int], str, list[dict[int, float]] | SampleLogprobs | None ] # Allow for tokens to be represented as str's rather than IDs; @@ -68,7 +69,7 @@ def check_outputs_equal( # # Assumes prompt logprobs were not requested. TextTextLogprobs = tuple[ - list[str], str, Optional[Union[list[dict[str, float]], list[dict[str, Logprob]]]] + list[str], str, list[dict[str, float]] | list[dict[str, Logprob]] | None ] # Representation of generated sequence as a tuple of @@ -81,18 +82,18 @@ def check_outputs_equal( TokensTextLogprobsPromptLogprobs = tuple[ list[int], str, - Optional[Union[list[dict[int, float]], SampleLogprobs]], - Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]], + list[dict[int, float]] | SampleLogprobs | None, + list[dict[int, float] | None] | PromptLogprobs | None, ] def check_logprobs_close( *, outputs_0_lst: Sequence[ - Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs] + TokensTextLogprobs | TokensTextLogprobsPromptLogprobs | TextTextLogprobs ], outputs_1_lst: Sequence[ - Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs] + TokensTextLogprobs | TokensTextLogprobsPromptLogprobs | TextTextLogprobs ], name_0: str, name_1: str, @@ -161,7 +162,7 @@ def check_logprobs_close( # Test prompt logprobs closeness if prompt_logprobs_0 is not None and prompt_logprobs_1 is not None: - # Both sequences' prompt logprobs lists are not `None`` + # Both sequences' prompt logprobs lists are not `None` # (although individual list elements may be `None`); # for each token's logprobs: for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate( @@ -273,9 +274,9 @@ def build_model_context( model_id: str, runner: RunnerOption = "auto", dtype: ModelDType = "auto", - model_config_kwargs: Optional[dict[str, Any]] = None, - mm_processor_kwargs: Optional[dict[str, Any]] = None, - limit_mm_per_prompt: Optional[dict[str, int]] = None, + model_config_kwargs: dict[str, Any] | None = None, + mm_processor_kwargs: dict[str, Any] | None = None, + limit_mm_per_prompt: dict[str, int] | None = None, mm_processor_cache_gb: int = 0, ): """Creates an InputProcessingContext for a given model. 
@@ -308,7 +309,9 @@ def build_model_context( limit_mm_per_prompt=limit_mm_per_prompt, mm_processor_cache_gb=mm_processor_cache_gb, hf_overrides=model_info.hf_overrides, - skip_tokenizer_init=model_info.skip_tokenizer_init, + skip_tokenizer_init=model_info.require_embed_inputs, + enable_prompt_embeds=model_info.require_embed_inputs, + enable_mm_embeds=model_info.require_embed_inputs, enforce_eager=model_info.enforce_eager, **model_config_kwargs, ) @@ -369,17 +372,18 @@ class ModelInfo: name: str architecture: str = "" dtype: str = "auto" + max_model_len: int | None = None hf_dtype: str = "float32" - hf_overrides: Optional[dict[str, Any]] = None + hf_overrides: dict[str, Any] | None = None default_pooling_type: str = "" enable_test: bool = True @dataclass class EmbedModelInfo(ModelInfo): - mteb_score: Optional[float] = None + mteb_score: float | None = None is_matryoshka: bool = False - matryoshka_dimensions: Optional[list[int]] = None + matryoshka_dimensions: list[int] | None = None @dataclass @@ -394,7 +398,7 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo): @dataclass class RerankModelInfo(ModelInfo): - mteb_score: Optional[float] = None + mteb_score: float | None = None @dataclass @@ -410,14 +414,43 @@ class LASTPoolingRerankModelInfo(RerankModelInfo): @dataclass class GenerateModelInfo(ModelInfo): hf_dtype: str = "auto" - hf_ppl: Optional[float] = None + hf_ppl: float | None = None + + +def get_vllm_extra_kwargs(model_info: ModelInfo, vllm_extra_kwargs): + # A model family has many models with the same architecture, + # and we don't need to test each one. + if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: + import pytest + + pytest.skip("Skipping test.") + + # Allow vllm to test using the given dtype, such as float32 + vllm_extra_kwargs = vllm_extra_kwargs or {} + vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype + + # Allow vllm to test using hf_overrides + if model_info.hf_overrides is not None: + vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides + + # Allow changing the head dtype used by vllm in tests + if ci_envs.VLLM_CI_HEAD_DTYPE is not None: + if "hf_overrides" not in vllm_extra_kwargs: + vllm_extra_kwargs["hf_overrides"] = {} + vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + + # Allow control over whether tests use enforce_eager + if ci_envs.VLLM_CI_ENFORCE_EAGER is not None: + vllm_extra_kwargs["enforce_eager"] = ci_envs.VLLM_CI_ENFORCE_EAGER + + return vllm_extra_kwargs def dummy_hf_overrides( hf_config: PretrainedConfig, *, model_arch: str = "", - exist_overrides: Optional[dict[str, Any]] = None, + exist_overrides: dict[str, Any] | None = None, use_original_num_layers: bool = False, ) -> PretrainedConfig: """ @@ -506,8 +539,8 @@ class DummyConfig: def check_transformers_version( model: str, - min_transformers_version: Optional[str] = None, - max_transformers_version: Optional[str] = None, + min_transformers_version: str | None = None, + max_transformers_version: str | None = None, ): from .registry import _HfExamplesInfo diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index fe983990b90c..531674c30f55 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import numpy as np import pytest @@ -32,7 +31,7 @@ def _dummy_elem( key: str, size: int, *, - rng: Optional[np.random.RandomState] = None, + 
rng: np.random.RandomState | None = None, ): if rng is None: data = torch.empty((size,), dtype=torch.int8) @@ -51,7 +50,7 @@ def _dummy_item( modality: str, size_by_key: dict[str, int], *, - rng: Optional[np.random.RandomState] = None, + rng: np.random.RandomState | None = None, ): return MultiModalKwargsItem.from_elems( [_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()] @@ -61,7 +60,7 @@ def _dummy_item( def _dummy_items( size_by_key_modality: dict[str, dict[str, int]], *, - rng: Optional[np.random.RandomState] = None, + rng: np.random.RandomState | None = None, ): return MultiModalKwargsItems.from_seq( [ diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index a542b068a42b..2f04bc6695c8 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import nullcontext -from typing import Optional, cast +from typing import cast import numpy as np import pytest @@ -1003,7 +1003,7 @@ def __call__( self, a: int = 0, c: int = 0, - return_tensors: Optional[str] = None, + return_tensors: str | None = None, ) -> dict[str, int]: return dict(a=a, c=c) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index a2a8d0ec9aba..5614f19d1a4f 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import base64 import datetime import os import tempfile import urllib.request from collections.abc import Sequence -from typing import Any, Union +from typing import Any import albumentations import numpy as np @@ -160,11 +158,11 @@ def read_geotiff( def load_image( - data: Union[list[str]], + data: list[str], path_type: str, mean: list[float] | None = None, std: list[float] | None = None, - indices: Union[list[int], None] | None = None, + indices: list[int] | None | None = None, ): """Build an input example by loading images in *file_paths*. 
@@ -280,7 +278,7 @@ def pre_process( prompt: IOProcessorInput, request_id: str | None = None, **kwargs, - ) -> Union[PromptType, Sequence[PromptType]]: + ) -> PromptType | Sequence[PromptType]: image_data = dict(prompt) if request_id: @@ -370,9 +368,9 @@ def post_process( out_format = "b64_json" for output in model_output: - y_hat = output.outputs.data.argmax(dim=1) + y_hat = output.outputs.data.argmax(dim=0) pred = torch.nn.functional.interpolate( - y_hat.unsqueeze(1).float(), + y_hat[None, None, ...].float(), size=self.img_size, mode="nearest", ) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py index 21a5c3754c36..d1d7873211f2 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Literal, Optional, TypedDict, Union +from typing import Any, Literal, TypedDict import albumentations from pydantic import BaseModel @@ -38,7 +38,7 @@ class ImagePrompt(BaseModel): """ -MultiModalPromptType = Union[ImagePrompt] +MultiModalPromptType = ImagePrompt class ImageRequestOutput(BaseModel): @@ -54,4 +54,4 @@ class ImageRequestOutput(BaseModel): type: Literal["path", "b64_json"] format: str data: str - request_id: Optional[str] = None + request_id: str | None = None diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index a22a10eab47d..98245cdf0c98 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional, Union import torch import torch.nn as nn @@ -31,7 +30,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), } ) @@ -44,9 +43,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 9e6f5c3a77e3..79af3ad842f5 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -20,7 +19,7 @@ dummy_inputs=LlavaDummyInputsBuilder, ) class MyLlava(LlavaForConditionalGeneration): - def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: 
torch.Tensor) -> torch.Tensor | None: # this dummy model always predicts the first token logits = super().compute_logits(hidden_states) if logits is not None: diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py index c02299f5d44f..f1e6e7b10f8b 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -9,7 +8,7 @@ class MyOPTForCausalLM(OPTForCausalLM): - def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: # this dummy model always predicts the first token logits = super().compute_logits(hidden_states) if logits is not None: diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py index c4fe6ed197f6..280b68514e19 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - -def dummy_platform_plugin() -> Optional[str]: +def dummy_platform_plugin() -> str | None: return "vllm_add_dummy_platform.dummy_platform.DummyPlatform" diff --git a/tests/plugins/vllm_add_dummy_stat_logger/dummy_stat_logger/dummy_stat_logger.py b/tests/plugins/vllm_add_dummy_stat_logger/dummy_stat_logger/dummy_stat_logger.py new file mode 100644 index 000000000000..66ec35c0d5c9 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_stat_logger/dummy_stat_logger/dummy_stat_logger.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.metrics.loggers import StatLoggerBase + + +class DummyStatLogger(StatLoggerBase): + """ + A dummy stat logger for testing purposes. + Implements the minimal interface expected by StatLoggerManager. 
+ """ + + def __init__(self, vllm_config, engine_idx=0): + self.vllm_config = vllm_config + self.engine_idx = engine_idx + self.recorded = [] + self.logged = False + self.engine_initialized = False + + def record(self, scheduler_stats, iteration_stats, mm_cache_stats, engine_idx): + self.recorded.append( + (scheduler_stats, iteration_stats, mm_cache_stats, engine_idx) + ) + + def log(self): + self.logged = True + + def log_engine_initialized(self): + self.engine_initialized = True diff --git a/tests/plugins/vllm_add_dummy_stat_logger/setup.py b/tests/plugins/vllm_add_dummy_stat_logger/setup.py new file mode 100644 index 000000000000..517017724bcc --- /dev/null +++ b/tests/plugins/vllm_add_dummy_stat_logger/setup.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from setuptools import setup + +setup( + name="dummy_stat_logger", + version="0.1", + packages=["dummy_stat_logger"], + entry_points={ + "vllm.stat_logger_plugins": [ + "dummy_stat_logger = dummy_stat_logger.dummy_stat_logger:DummyStatLogger" # noqa + ] + }, +) diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 912b32755e80..582cf9a0711b 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -9,7 +9,6 @@ from vllm.config import VllmConfig from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.plugins.io_processors import get_io_processor -from vllm.pooling_params import PoolingParams MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" @@ -38,6 +37,7 @@ def server(): "prithvi_to_tiff", "--model-impl", "terratorch", + "--enable-mm-embeds", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -93,12 +93,11 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): out_data_format="b64_json", ) - pooling_params = PoolingParams(task="encode", softmax=False) - with vllm_runner( model_name, runner="pooling", skip_tokenizer_init=True, + enable_mm_embeds=True, trust_remote_code=True, enforce_eager=True, # Limit the maximum number of parallel requests @@ -107,10 +106,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): model_impl="terratorch", io_processor_plugin="prithvi_to_tiff", ) as llm_runner: - pooler_output = llm_runner.get_llm().encode( - img_prompt, - pooling_params=pooling_params, - ) + pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin") output = pooler_output[0].outputs # verify the output is formatted as expected for this plugin diff --git a/tests/plugins_tests/test_stats_logger_plugins.py b/tests/plugins_tests/test_stats_logger_plugins.py new file mode 100644 index 000000000000..eb03b1fde417 --- /dev/null +++ b/tests/plugins_tests/test_stats_logger_plugins.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from dummy_stat_logger.dummy_stat_logger import DummyStatLogger + +from vllm.config import VllmConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.metrics.loggers import load_stat_logger_plugin_factories + + +def test_stat_logger_plugin_is_discovered(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_PLUGINS", "dummy_stat_logger") + + factories = load_stat_logger_plugin_factories() + assert len(factories) == 1, 
f"Expected 1 factory, got {len(factories)}" + assert factories[0] is DummyStatLogger, ( + f"Expected DummyStatLogger class, got {factories[0]}" + ) + + # instantiate and confirm the right type + vllm_config = VllmConfig() + instance = factories[0](vllm_config) + assert isinstance(instance, DummyStatLogger) + + +def test_no_plugins_loaded_if_env_empty(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_PLUGINS", "") + + factories = load_stat_logger_plugin_factories() + assert factories == [] + + +def test_invalid_stat_logger_plugin_raises(monkeypatch: pytest.MonkeyPatch): + def fake_plugin_loader(group: str): + assert group == "vllm.stat_logger_plugins" + return {"bad": object()} + + with monkeypatch.context() as m: + m.setattr( + "vllm.v1.metrics.loggers.load_plugins_by_group", + fake_plugin_loader, + ) + with pytest.raises( + TypeError, + match="Stat logger plugin 'bad' must be a subclass of StatLoggerBase", + ): + load_stat_logger_plugin_factories() + + +@pytest.mark.asyncio +async def test_stat_logger_plugin_integration_with_engine( + monkeypatch: pytest.MonkeyPatch, +): + with monkeypatch.context() as m: + m.setenv("VLLM_PLUGINS", "dummy_stat_logger") + + engine_args = AsyncEngineArgs( + model="facebook/opt-125m", + enforce_eager=True, # reduce test time + disable_log_stats=True, # disable default loggers + ) + + engine = AsyncLLM.from_engine_args(engine_args=engine_args) + + assert len(engine.logger_manager.stat_loggers) == 2 + assert len(engine.logger_manager.stat_loggers[0].per_engine_stat_loggers) == 1 + assert isinstance( + engine.logger_manager.stat_loggers[0].per_engine_stat_loggers[0], + DummyStatLogger, + ) + + engine.shutdown() diff --git a/tests/quantization/fp_quant.py b/tests/quantization/fp_quant.py new file mode 100644 index 000000000000..664ce9d111e4 --- /dev/null +++ b/tests/quantization/fp_quant.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test model set-up and inference for quantized HF models supported +on the GPU backend using FPQuant. + +Validating the configuration and printing results for manual checking. + +Run `pytest tests/quantization/test_fp_quant.py`. 
+""" + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +MODELS = [ + "ISTA-DASLab/Qwen3-0.6B-RTN-NVFP4", + "ISTA-DASLab/Qwen3-0.6B-RTN-MXFP4", +] +DTYPE = ["bfloat16"] +EAGER = [True, False] + + +@pytest.mark.skipif( + not is_quant_method_supported("fp_quant"), + reason="FPQuant is not supported on this GPU type.", +) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("eager", EAGER) +def test_fpquant(vllm_runner, model, eager): + with vllm_runner(model, enforce_eager=eager) as llm: + output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) + assert output[0][1] == "1 2 3 4 5 6" diff --git a/tests/quantization/test_auto_round.py b/tests/quantization/test_auto_round.py index 69632ae6cac7..9f5db8219501 100644 --- a/tests/quantization/test_auto_round.py +++ b/tests/quantization/test_auto_round.py @@ -26,7 +26,7 @@ ) @pytest.mark.parametrize("model", MODELS) def test_auto_round(vllm_runner, model): - with vllm_runner(model) as llm: + with vllm_runner(model, enforce_eager=True) as llm: output = llm.generate_greedy(["The capital of France is"], max_tokens=8) assert output print(f"{output[0][1]}") diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py index 4a0f701ae3cb..3cae6f46147b 100644 --- a/tests/quantization/test_blackwell_moe.py +++ b/tests/quantization/test_blackwell_moe.py @@ -3,7 +3,7 @@ import json import os -from typing import Optional +from typing import Any import pytest @@ -25,12 +25,21 @@ def set_test_environment(): os.environ["FLASHINFER_NVCC_THREADS"] = "16" -# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4, -# "text_config": {"num_layers": 4, "num_hidden_layers": 4}} -dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4} +# Overide the backbone layers to 4 for faster startup +HF_OVERRIDE_TEXT = { + "num_layers": 4, + "num_hidden_layers": 4, +} +HF_OVERRIDE_MM = { + "text_config": {"num_layers": 4, "num_hidden_layers": 4}, +} -def can_initialize(model: str, extra_args: Optional[list[str]] = None): +def can_initialize( + model: str, + hf_overrides: dict[str, Any] | None = None, + extra_args: list[str] | None = None, +): # Server arguments extra_args = extra_args if extra_args is not None else [] server_args = [ @@ -50,8 +59,8 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None): with RemoteOpenAIServer( model, server_args, - max_wait_seconds=1000, # Due to FlashInfer compile - override_hf_configs=dummy_hf_overrides, + max_wait_seconds=1500, # Due to FlashInfer compile + override_hf_configs=hf_overrides, ) as server: client = server.get_client() # Make a simple request to verify the server works @@ -78,28 +87,33 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None): def test_llama4_fp8_tensor_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") - can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8") + can_initialize( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", hf_overrides=HF_OVERRIDE_MM + ) -@pytest.mark.skip(reason="Works, but takes too long to run") def test_llama4_fp8_tensor_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") - can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8") + can_initialize( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", 
hf_overrides=HF_OVERRIDE_MM + ) -@pytest.mark.skip(reason="Works, but takes too long to run") def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") - can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4") + can_initialize( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", hf_overrides=HF_OVERRIDE_MM + ) -@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options") def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") - can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4") + can_initialize( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", hf_overrides=HF_OVERRIDE_MM + ) ## DeepSeekV3 ## @@ -107,7 +121,7 @@ def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1") - can_initialize("deepseek-ai/DeepSeek-V3.1") + can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT) @pytest.mark.skip( @@ -119,26 +133,25 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch): def test_deepseek_fp8_block_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") - can_initialize("deepseek-ai/DeepSeek-V3.1") + can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT) def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") - can_initialize("deepseek-ai/DeepSeek-V3.1") + can_initialize("deepseek-ai/DeepSeek-V3.1", hf_overrides=HF_OVERRIDE_TEXT) def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") - can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2") + can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", hf_overrides=HF_OVERRIDE_TEXT) -@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options") def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") - can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2") + can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", hf_overrides=HF_OVERRIDE_TEXT) ## GPT-OSS ## @@ -146,14 +159,34 @@ def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1") - can_initialize("openai/gpt-oss-20b") + can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT) def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1") - can_initialize("openai/gpt-oss-20b") + can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT) def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1") - 
can_initialize("openai/gpt-oss-20b") + can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT) + + +def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1") + monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") + can_initialize( + "openai/gpt-oss-20b", + extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"], + hf_overrides=HF_OVERRIDE_TEXT, + ) + + +def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1") + monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput") + can_initialize( + "openai/gpt-oss-20b", + extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"], + hf_overrides=HF_OVERRIDE_TEXT, + ) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 824d927724e0..e7d902ed26aa 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -5,8 +5,6 @@ Run `pytest tests/quantization/test_compressed_tensors.py`. """ -from typing import Optional - import pytest import torch from compressed_tensors.quantization import QuantizationType @@ -68,13 +66,6 @@ def enable_pickle(monkeypatch): 2560, True, ), - ( - "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", - "channel", - QuantizationType.INT, - 2560, - True, - ), ( "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor", @@ -104,7 +95,7 @@ def check_model(model): down_proj = layer.mlp.down_proj # assert zp for symmetric and asymmetric cases - def zp_valid(zp: Optional[torch.Tensor]): + def zp_valid(zp: torch.Tensor | None): if is_symmetric: return zp is None @@ -140,7 +131,7 @@ def zp_valid(zp: Optional[torch.Tensor]): llm.apply_model(check_model) - output = llm.generate_greedy(["Hello my name is"], max_tokens=20) + output = llm.generate_greedy(["Hello my name is"], max_tokens=4) assert output @@ -148,12 +139,9 @@ def zp_valid(zp: Optional[torch.Tensor]): "model_path", [ "neuralmagic/Llama-3.2-1B-quantized.w8a8", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", ], ) -@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("max_tokens", [8]) @pytest.mark.parametrize("num_logprobs", [10]) @pytest.mark.parametrize( "use_aiter", [True, False] if current_platform.is_rocm() else [False] @@ -213,7 +201,7 @@ def test_compressed_tensors_w8a8_logprobs( def test_compressed_tensors_no_enforce_eager(vllm_runner): model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change" with vllm_runner(model_path) as llm: - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output @@ -221,15 +209,10 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner): "model_args", [ ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"), - ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"), ( "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel", ), - ( - "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym", - "channel", - ), ], ) @pytest.mark.parametrize( @@ -255,7 +238,7 @@ def test_compressed_tensors_w8a8_dynamic_per_token( # this will enable VLLM_ROCM_USE_AITER_LINEAR 
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - with vllm_runner(model_path, dtype=torch.float16) as llm: + with vllm_runner(model_path, enforce_eager=True, dtype=torch.float16) as llm: def check_model(model): layer = model.model.layers[0] @@ -270,7 +253,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy(["Hello my name is"], max_tokens=20) + output = llm.generate_greedy(["Hello my name is"], max_tokens=4) assert output @@ -285,38 +268,6 @@ def check_model(model): True, False, ), - ( - "nm-testing/tinyllama-oneshot-w4a16-group128-v2", - "group", - 128, - 8, - True, - False, - ), - ( - "nm-testing/tinyllama-oneshot-w8a16-per-channel", - "channel", - None, - 4, - True, - False, - ), - ( - "nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", - "group", - 128, - 8, - False, - False, - ), - ( - "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel", - "channel", - None, - 8, - False, - False, - ), ( "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder", "group", @@ -332,7 +283,7 @@ def check_model(model): ) def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args - with vllm_runner(model) as llm: + with vllm_runner(model, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -350,7 +301,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output @@ -359,7 +310,7 @@ def check_model(model): ) def test_compressed_tensors_w4a16_marlin24(vllm_runner): model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t" - with vllm_runner(model_path) as llm: + with vllm_runner(model_path, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -372,13 +323,13 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output def test_compressed_tensors_fp8(vllm_runner): model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test" - with vllm_runner(model_path) as llm: + with vllm_runner(model_path, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -401,21 +352,17 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output -@pytest.mark.skipif( - not current_platform.is_kv_cache_dtype_supported("fp8", None), - reason="FP8 KV cache is not supported on this device.", -) @pytest.mark.skipif( not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
) def test_compressed_tensors_kv_cache(vllm_runner): model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" - with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: - output = llm.generate_greedy("Hello world!", max_tokens=20) + with vllm_runner(model_path, enforce_eager=True, kv_cache_dtype="fp8") as llm: + output = llm.generate_greedy("Hello world!", max_tokens=4) assert output @@ -467,7 +414,7 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, format="d ) def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 - with vllm_runner(model) as llm: + with vllm_runner(model, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -478,7 +425,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) print(output) assert output @@ -514,7 +461,7 @@ def check_model(model): ) def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 - with vllm_runner(model) as llm: + with vllm_runner(model, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -530,7 +477,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) print(output) assert output @@ -566,7 +513,7 @@ def check_model(model): ) def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 - with vllm_runner(model) as llm: + with vllm_runner(model, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -582,7 +529,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) print(output) assert output @@ -613,7 +560,7 @@ def check_model(model): ) def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4): model, weight_strategy, input_strategy = args_2of4 - with vllm_runner(model) as llm: + with vllm_runner(model, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -624,7 +571,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) print(output) assert output @@ -639,7 +586,7 @@ def check_model(model): ) def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4): model = args_2of4 - with vllm_runner(model) as llm: + with vllm_runner(model, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -658,7 +605,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) print(output) assert output @@ -672,7 +619,7 @@ def check_model(model): ) def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): model = args_2of4 - with vllm_runner(model) as llm: + with vllm_runner(model, enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -691,7 +638,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", 
max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) print(output) assert output @@ -699,7 +646,8 @@ def check_model(model): @pytest.mark.parametrize( "args", [ - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4), + # TODO: Enable once model is available again + # ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4), ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4), ], ) @@ -724,7 +672,7 @@ def check_model(model): assert qkv_proj.scheme.group_size == 16 llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) print(output) assert output @@ -759,7 +707,7 @@ def check_model(model): assert proj.scheme.group_size == 128 llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) print(output) assert output @@ -793,7 +741,7 @@ def test_compressed_tensors_transforms_perplexity( def test_compressed_tensors_fp8_block_enabled(vllm_runner): model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" - with vllm_runner(model_path) as llm: + with vllm_runner(model_path, enforce_eager=True) as llm: fp8_dtype = current_platform.fp8_dtype() def check_model(model): @@ -817,5 +765,5 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output diff --git a/tests/quantization/test_cpu_offload.py b/tests/quantization/test_cpu_offload.py index 25d1dc59f617..a3fb4a695347 100644 --- a/tests/quantization/test_cpu_offload.py +++ b/tests/quantization/test_cpu_offload.py @@ -16,13 +16,6 @@ reason="fp8 is not supported on this GPU type.", ) def test_cpu_offload_fp8(): - # Test quantization of an unquantized checkpoint - compare_two_settings( - "meta-llama/Llama-3.2-1B-Instruct", - ["--quantization", "fp8"], - ["--quantization", "fp8", "--cpu-offload-gb", "1"], - max_wait_seconds=480, - ) # Test loading a quantized checkpoint compare_two_settings( "neuralmagic/Qwen2-1.5B-Instruct-FP8", @@ -46,13 +39,6 @@ def test_cpu_offload_gptq(monkeypatch): ["--cpu-offload-gb", "1"], max_wait_seconds=480, ) - # Test GPTQ - compare_two_settings( - "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", - ["--quantization", "gptq"], - ["--quantization", "gptq", "--cpu-offload-gb", "1"], - max_wait_seconds=480, - ) @pytest.mark.skipif( @@ -69,13 +55,6 @@ def test_cpu_offload_awq(monkeypatch): ["--cpu-offload-gb", "1"], max_wait_seconds=480, ) - # Test AWQ - compare_two_settings( - "Qwen/Qwen2-1.5B-Instruct-AWQ", - ["--quantization", "awq"], - ["--quantization", "awq", "--cpu-offload-gb", "1"], - max_wait_seconds=480, - ) @pytest.mark.skipif( @@ -92,17 +71,3 @@ def test_cpu_offload_compressed_tensors(monkeypatch): ["--cpu-offload-gb", "1"], max_wait_seconds=480, ) - # Test w4a16_marlin24 - compare_two_settings( - "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", - [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480, - ) - # Test w8a8 - compare_two_settings( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", - [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480, - ) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 6b9a33059815..7f863a169d5f 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -18,7 +18,6 @@ MODELS = [ 
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", - "nm-testing/Phi-3-mini-128k-instruct-FP8", "nm-testing/Qwen2-0.5B-Instruct-FP8-SkipQKV", ] @@ -49,8 +48,6 @@ def test_model_load_and_run( KV_CACHE_MODELS = [ - # Deprecated AutoFP8 format using .kv_scale - "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", # AutoFP8 format using separate .k_scale and .v_scale "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", ] diff --git a/tests/quantization/test_gptq_dynamic.py b/tests/quantization/test_gptq_dynamic.py index c71f4b815611..37fe2dd3243a 100644 --- a/tests/quantization/test_gptq_dynamic.py +++ b/tests/quantization/test_gptq_dynamic.py @@ -40,7 +40,9 @@ def test_gptq_with_dynamic( GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod) ) - with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm: + with vllm_runner( + model_id, dtype=torch.float16, max_model_len=2048, enforce_eager=True + ) as llm: def check_model(model): for name, submodule in model.named_modules(): diff --git a/tests/quantization/test_gptq_v2.py b/tests/quantization/test_gptq_v2.py new file mode 100644 index 000000000000..dbafa2e8e7d1 --- /dev/null +++ b/tests/quantization/test_gptq_v2.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests whether vllm correctly load and run gptq_v2 format checkpoints. + +Run `pytest tests/quantization/test_gptq_v2.py --forked`. +""" + +import pytest +import torch +from transformers import AutoTokenizer + +from vllm import SamplingParams +from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod + +# A dummy small model quantized by GPTQModel, stored in GPTQ v2 format +MODELS = ["XXXXyu/Qwen3-1.7B-w2g64-gptq_v2"] + +# Generate multiple sequences for testing, because an 1.7B 2-bit model +# cannot always generate normal texts. +N_SEQ = 5 + + +@pytest.mark.parametrize("model_id", MODELS) +def test_model_load(vllm_runner, model_id, monkeypatch): + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + # Only check the default GPTQ linear method (used for 2/3-bit models). + # 4/8-bit linear methods like Marlin already support gptq_v2. + linear_method_cls = GPTQLinearMethod + + with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm: + + def check_model(model_id): + for name, submodule in model_id.named_modules(): + # Could check more modules if necessary + if name == "model_id.layers.0.self_attn.qkv_proj": + assert isinstance(submodule.quant_method, linear_method_cls) + + config = submodule.quant_method.quant_config + assert config.checkpoint_format == "gptq_v2" + assert submodule.quant_method.use_v2_format + + # Just break since currently we only check 1 module + break + + # Check if gptq_v2 format is correctly loaded + llm.apply_model(check_model) + + +@pytest.mark.parametrize("model_id", MODELS) +def test_model_inference(vllm_runner, model_id): + # Prepare prompt to test the model's generation result. + prompt = "What is the meaning of life?" 
+ messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + tokenizer = AutoTokenizer.from_pretrained(model_id) + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, # If using a thinking model, set this to False + ) + sampling_params = SamplingParams( + n=N_SEQ, + max_tokens=128, + temperature=0.7, + top_p=0.8, + top_k=20, + min_p=0, + presence_penalty=2, + ) + + with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm: + # Generate a response to verify inference correctness + output = llm.generate(text, sampling_params) + + # Make sure the output exists + assert output + assert output[0][1] + assert len(output[0][1]) == N_SEQ + + def has_normal_char_distribution(texts, min_len): + for text in texts: + # Response too short + if len(text) < min_len: + return False + + # Basic ratio checks + letters = sum(c.isalpha() for c in text) + spaces = sum(c.isspace() for c in text) + total = len(text) + + letter_ratio = letters / total + space_ratio = spaces / total + + # At least 1 normal text should exist within output sequences + # Normal text should be mostly letters with reasonable spacing + # Some magic numbers, could be adjusted + if 0.5 <= letter_ratio <= 0.9 and 0.01 <= space_ratio <= 0.3: + return True + # No sequence contains normal text, output might be broken + return False + + # Apply some simple checks for gibberish output + # Print the output sequences if failed + assert has_normal_char_distribution(output[0][1], 5), output[0][1] diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index bae8b7f7d535..f009a4cfb870 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -31,7 +31,9 @@ def test_lm_head( ) -> None: # `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as vllm_model: + with vllm_runner( + model_id, dtype=torch.float16, max_model_len=2048, enforce_eager=True + ) as vllm_model: def check_model(model): lm_head_layer = model.lm_head diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 1e65d9a995ce..0af27aff9359 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -11,7 +11,6 @@ import os from dataclasses import dataclass from importlib.util import find_spec -from typing import Optional import huggingface_hub import lm_eval @@ -57,7 +56,10 @@ def enable_pickle(monkeypatch): def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp): model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test" with vllm_runner( - model_path, kv_cache_dtype=kv_cache_dtype, tensor_parallel_size=tp + model_path, + enforce_eager=True, + kv_cache_dtype=kv_cache_dtype, + tensor_parallel_size=tp, ) as llm: def check_model(model): @@ -75,14 +77,14 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output @pytest.mark.parametrize("tp", [1]) def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp): model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts" - with vllm_runner(model_path, tensor_parallel_size=tp) as llm: + with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm: def check_model(model): layer = model.model.layers[0] @@ -99,14 +101,14 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output @pytest.mark.parametrize("tp", [1]) def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp): model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test" - with vllm_runner(model_path, tensor_parallel_size=tp) as llm: + with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm: def check_model(model): layer = model.model.layers[0] @@ -118,7 +120,7 @@ def check_model(model): llm.apply_model(check_model) - output = llm.generate_greedy("Hello my name is", max_tokens=20) + output = llm.generate_greedy("Hello my name is", max_tokens=4) assert output @@ -156,8 +158,8 @@ class AccuracyTestConfig: def get_model_args( self, tp_size: int, - model_max_len: Optional[int] = None, - kwargs: Optional[dict] = None, + model_max_len: int | None = None, + kwargs: dict | None = None, ) -> dict: if kwargs is None: kwargs = {} diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index b70c2ee7fe2e..aeef4c2fd8a7 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -7,7 +7,7 @@ Run `pytest tests/quantization/test_register_quantization_config.py`. 
""" -from typing import Any, Optional +from typing import Any import pytest import torch @@ -37,10 +37,10 @@ def __init__(self, num_bits: int = 8) -> None: def apply( self, - layer: "torch.nn.Module", - x: "torch.Tensor", - bias: Optional["torch.Tensor"] = None, - ) -> "torch.Tensor": + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: """Perform fake quantization before the linear layer.""" # Calculate the scales dynamically @@ -72,7 +72,7 @@ def get_name(self) -> QuantizationMethods: """Name of the quantization method.""" return "custom_quant" - def get_supported_act_dtypes(self) -> list["torch.dtype"]: + def get_supported_act_dtypes(self) -> list[torch.dtype]: """List of supported activation dtypes.""" return [torch.float16, torch.bfloat16] @@ -92,8 +92,8 @@ def from_config(cls, config: dict[str, Any]) -> "CustomQuantConfig": return CustomQuantConfig(num_bits=config.get("num_bits", 8)) def get_quant_method( - self, layer: "torch.nn.Module", prefix: str - ) -> Optional["FakeQuantLinearMethod"]: + self, layer: torch.nn.Module, prefix: str + ) -> FakeQuantLinearMethod | None: """Get the quantize method to use for the quantized layer.""" if isinstance(layer, LinearBase): return FakeQuantLinearMethod(num_bits=self.num_bits) diff --git a/tests/quantization/test_rtn.py b/tests/quantization/test_rtn.py index 370625ed3479..195f1fbbdfc0 100644 --- a/tests/quantization/test_rtn.py +++ b/tests/quantization/test_rtn.py @@ -10,7 +10,6 @@ from tests.quantization.utils import is_quant_method_supported MODELS = [ - "microsoft/Phi-3-mini-4k-instruct", # dense model "ai21labs/Jamba-tiny-dev", # MoE model ] @@ -30,5 +29,7 @@ def test_model_rtn_startup( dtype: str, max_tokens: int, ) -> None: - with vllm_runner(model, dtype=dtype, quantization="rtn") as vllm_model: + with vllm_runner( + model, enforce_eager=True, dtype=dtype, quantization="rtn" + ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index d1cf7e163596..cab198a2a15e 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -19,7 +19,7 @@ def test_pre_quantized_model(vllm_runner): dtype="bfloat16", enforce_eager=True, ) as llm: - output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) assert output @@ -39,8 +39,9 @@ def test_opt_125m_int8wo_model_loading_with_params(vllm_runner, pt_load_map_loca quantization="torchao", dtype="bfloat16", pt_load_map_location=pt_load_map_location, + enforce_eager=True, ) as llm: - output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) assert output @@ -54,8 +55,9 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner): quantization="torchao", dtype="bfloat16", pt_load_map_location="cuda:0", + enforce_eager=True, ) as llm: - output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) assert output @@ -69,8 +71,9 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner): quantization="torchao", dtype="bfloat16", pt_load_map_location="cuda:0", + enforce_eager=True, ) as llm: - output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) assert output @@ 
-90,7 +93,7 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner): dtype="bfloat16", pt_load_map_location="cuda:0", ) as llm: - output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) assert output @@ -122,8 +125,9 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner): pt_load_map_location="cuda:0", quantization="torchao", hf_overrides=hf_overrides, + enforce_eager=True, ) as llm: - output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) assert output @@ -156,8 +160,9 @@ def test_on_the_fly_quant_config_file(vllm_runner): pt_load_map_location="cuda:0", quantization="torchao", hf_overrides=hf_overrides, + enforce_eager=True, ) as llm: - output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) assert output @@ -228,7 +233,24 @@ def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_ "torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors" ) with vllm_runner(model_name=model_name, dtype="bfloat16") as llm: - output = llm.generate_greedy(["The capital of France is"], max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) + + assert output + + +@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") +@pytest.mark.skip( + reason="since torchao nightly is only compatible with torch nightly" + "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) +def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner): + torch._dynamo.reset() + model_name = "torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev" + with vllm_runner( + model_name=model_name, dtype="bfloat16", pt_load_map_location="cuda:0" + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=4) assert output diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py new file mode 100644 index 000000000000..3d12f3e5b30e --- /dev/null +++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from vllm.reasoning import ( + DeepSeekR1ReasoningParser, + DeepSeekV3ReasoningParser, + IdentityReasoningParser, +) + +REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-V3.1" + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +@pytest.mark.parametrize( + "thinking,expected_parser_type", + [ + (True, DeepSeekR1ReasoningParser), + (False, IdentityReasoningParser), + ], +) +def test_parser_selection(tokenizer, thinking, expected_parser_type): + parser = DeepSeekV3ReasoningParser( + tokenizer, chat_template_kwargs={"thinking": thinking} + ) + + assert isinstance(parser._parser, expected_parser_type) + + +def test_identity_reasoning_parser_basic(tokenizer): + parser = IdentityReasoningParser(tokenizer) + + # Test is_reasoning_end always returns True + input_text = "This is some output" + input_tokens = tokenizer.tokenize(input_text) + input_ids = 
tokenizer.convert_tokens_to_ids(input_tokens) + assert parser.is_reasoning_end(input_ids) is True + + # Test extract_content_ids returns all input_ids + assert parser.extract_content_ids(input_ids) == input_ids + + # Test extract_reasoning_content returns (None, model_output) + request = ChatCompletionRequest(model="test-model", messages=[], temperature=1.0) + reasoning, content = parser.extract_reasoning_content(input_text, request) + assert reasoning is None + assert content == input_text + + # Test extract_reasoning_content_streaming returns DeltaMessage or None + result = parser.extract_reasoning_content_streaming( + previous_text="", + current_text="Hello world", + delta_text="Hello world", + previous_token_ids=[], + current_token_ids=input_ids, + delta_token_ids=input_ids, + ) + assert isinstance(result, DeltaMessage) + assert result.content == "Hello world" + + # If delta_text is empty, should return None + result_none = parser.extract_reasoning_content_streaming( + previous_text="Hello world", + current_text="Hello world", + delta_text="", + previous_token_ids=input_ids, + current_token_ids=input_ids, + delta_token_ids=[], + ) + assert result_none is None diff --git a/tests/reasoning/test_ernie45_reasoning_parser.py b/tests/reasoning/test_ernie45_reasoning_parser.py new file mode 100644 index 000000000000..344478013e6b --- /dev/null +++ b/tests/reasoning/test_ernie45_reasoning_parser.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "ernie45" + +REASONING_MODEL_NAME = "baidu/ERNIE-4.5-21B-A3B-Thinking" + + +@pytest.fixture(scope="module") +def ernie45_tokenizer(): + return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) + + +# with </think>, non-stream +WITH_THINK = { + "output": "abc</think>def", + "reasoning_content": "abc", + "content": "def", +} +# with </think>, stream +WITH_THINK_STREAM = { + "output": "abc</think>def", + "reasoning_content": "abc", + "content": "def", +} +# without </think>, all is reasoning_content +WITHOUT_THINK = { + "output": "abc", + "reasoning_content": "abc", + "content": None, +} +# without </think>, all is reasoning_content +WITHOUT_THINK_STREAM = { + "output": "abc", + "reasoning_content": "abc", + "content": None, +} + +COMPLETE_REASONING = { + "output": "abc</think>", + "reasoning_content": "abc", + "content": None, +} +MULTILINE_REASONING = { + "output": "abc\nABC</think>def\nDEF", + "reasoning_content": "abc\nABC", + "content": "def\nDEF", +} + +TEST_CASES = [ + pytest.param( + False, + WITH_THINK, + id="with_think", + ), + pytest.param( + True, + WITH_THINK_STREAM, + id="with_think_stream", + ), + pytest.param( + False, + WITHOUT_THINK, + id="without_think", + ), + pytest.param( + True, + WITHOUT_THINK_STREAM, + id="without_think_stream", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_stream", + ), + pytest.param( + False, + MULTILINE_REASONING, + id="multiline_reasoning", + ), + pytest.param( + True, + MULTILINE_REASONING, + id="multiline_reasoning_stream", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, + ernie45_tokenizer, +): + output = 
ernie45_tokenizer.tokenize(param_dict["output"]) + output_tokens: list[str] = [] + for token in output: + one_token = ernie45_tokenizer.convert_tokens_to_string([token]) + if one_token: + output_tokens.append(one_token) + + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + ernie45_tokenizer + ) + + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) + + print() + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] diff --git a/tests/reasoning/test_mistral_reasoning_parser.py b/tests/reasoning/test_mistral_reasoning_parser.py index 96107c0c1193..ff7f94b40ee1 100644 --- a/tests/reasoning/test_mistral_reasoning_parser.py +++ b/tests/reasoning/test_mistral_reasoning_parser.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from mistral_common.tokens.tokenizers.base import SpecialTokens -from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer from tests.reasoning.utils import run_reasoning_extraction_mistral from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -14,33 +12,9 @@ @pytest.fixture(scope="module") def mistral_tokenizer(): - # TODO(Julien): upon model release change to a tokenizer already configured. - # ================================================================= mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Devstral-Small-2507" + "mistralai/Magistral-Small-2509" ) - assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) - # Add think special tokens to the tokenizer - mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( - rank=35, is_control=True, token_str=SpecialTokens.begin_think.value - ) - mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( - rank=36, is_control=True, token_str=SpecialTokens.end_think.value - ) - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { - k: v - for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() - if v not in {35, 36} - } - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.begin_think.value - ] = 35 - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.end_think.value - ] = 36 - mistral_tokenizer.instruct.BEGIN_THINK = 35 - mistral_tokenizer.instruct.END_THINK = 36 - # ================================================================= return mistral_tokenizer diff --git a/tests/reasoning/utils.py b/tests/reasoning/utils.py index 788136e99681..ccd4ff8dd263 100644 --- a/tests/reasoning/utils.py +++ b/tests/reasoning/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.reasoning import ReasoningParser @@ -34,9 +33,9 @@ def append_delta(self, delta: DeltaMessage): def run_reasoning_extraction( reasoning_parser: ReasoningParser, model_output: list[str], - request: Union[ChatCompletionRequest, None] = None, + request: ChatCompletionRequest | None = None, streaming: bool = False, -) -> tuple[Optional[str], Optional[str]]: +) -> tuple[str | None, str | None]: if streaming: reconstructor = run_reasoning_extraction_streaming( reasoning_parser, @@ -57,9 +56,9 @@ def run_reasoning_extraction( def run_reasoning_extraction_mistral( reasoning_parser: ReasoningParser, model_output: list[int], - 
request: Union[ChatCompletionRequest, None] = None, + request: ChatCompletionRequest | None = None, streaming: bool = False, -) -> tuple[Optional[str], Optional[str]]: +) -> tuple[str | None, str | None]: assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type( reasoning_parser.model_tokenizer ) @@ -86,8 +85,8 @@ def run_reasoning_extraction_mistral( def run_reasoning_extraction_nonstreaming( reasoning_parser: ReasoningParser, model_output: list[str], - request: Union[ChatCompletionRequest, None] = None, -) -> tuple[Optional[str], Optional[str]]: + request: ChatCompletionRequest | None = None, +) -> tuple[str | None, str | None]: request = request or ChatCompletionRequest(messages=[], model="test-model") return reasoning_parser.extract_reasoning_content( model_output="".join(model_output), request=request @@ -97,7 +96,7 @@ def run_reasoning_extraction_nonstreaming( def run_reasoning_extraction_streaming( reasoning_parser: ReasoningParser, model_deltas: list[str], - request: Union[ChatCompletionRequest, None] = None, + request: ChatCompletionRequest | None = None, ) -> StreamingReasoningReconstructor: request = request or ChatCompletionRequest(messages=[], model="test-model") reconstructor = StreamingReasoningReconstructor() @@ -129,7 +128,7 @@ def run_reasoning_extraction_streaming( def run_reasoning_extraction_streaming_mistral( reasoning_parser: ReasoningParser, model_deltas: list[int], - request: Union[ChatCompletionRequest, None] = None, + request: ChatCompletionRequest | None = None, ) -> StreamingReasoningReconstructor: assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type( reasoning_parser.model_tokenizer diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index fa0ca48f9bd9..74047d2f0355 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -6,8 +6,6 @@ """ -from typing import Optional - from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -18,7 +16,7 @@ def _generate( prompt: str, num_prompt_tokens: int, temperature: float = 0, - bad_words: Optional[list[str]] = None, + bad_words: list[str] | None = None, ) -> list[int]: sampling_params = SamplingParams( temperature=temperature, @@ -37,15 +35,13 @@ def _generate( class TestOneTokenBadWord: - MODEL = "TheBloke/Llama-2-7B-fp16" + MODEL = "hmellor/tiny-random-LlamaForCausalLM" - PROMPT = "Hi! 
How are" - TARGET_TOKEN = "you" + PROMPT = "How old are " + TARGET_TOKEN = "mn" def setup_method(self, method): - self.tokenizer = AutoTokenizer.from_pretrained( - self.MODEL, add_prefix_space=True - ) + self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL) self.num_prompt_tokens = len(self._encode(self.PROMPT)) self.target_token_id = self._encode( @@ -60,7 +56,7 @@ def test_one_token_bad_word(self, vllm_runner): output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN]) assert self.target_token_id not in output_token_ids - def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]: + def _generate(self, llm: LLM, bad_words: list[str] | None = None) -> list[int]: return _generate( llm=llm, prompt=self.PROMPT, @@ -155,7 +151,7 @@ def test_two_token_bad_word(self, vllm_runner): self.neighbour_token_id2 in output_token_ids ) - def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]: + def _generate(self, llm: LLM, bad_words: list[str] | None = None) -> list[int]: return _generate( llm=llm, prompt=self.PROMPT, diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index 5ce6e1593b5c..19ba32d8dee4 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ b/tests/speculative_decoding/speculators/test_eagle3.py @@ -22,6 +22,10 @@ "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", id="qwen3-eagle3-speculator-w4a16-verifier", ), + pytest.param( + "nm-testing/random-weights-llama3.1.8b-2layer-eagle3", + id="llama3-eagle3-multiple-layers", + ), ], ) def test_eagle3_speculators_model( diff --git a/tests/test_envs.py b/tests/test_envs.py index 62d529c36360..023767505f10 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -6,7 +6,54 @@ import pytest -from vllm.envs import env_list_with_choices, env_with_choices +import vllm.envs as envs +from vllm.envs import ( + enable_envs_cache, + env_list_with_choices, + env_with_choices, + environment_variables, +) + + +def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch): + assert envs.VLLM_HOST_IP == "" + assert envs.VLLM_PORT is None + monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1") + monkeypatch.setenv("VLLM_PORT", "1234") + assert envs.VLLM_HOST_IP == "1.1.1.1" + assert envs.VLLM_PORT == 1234 + # __getattr__ is not decorated with functools.cache + assert not hasattr(envs.__getattr__, "cache_info") + + +def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1") + monkeypatch.setenv("VLLM_PORT", "1234") + # __getattr__ is not decorated with functools.cache + assert not hasattr(envs.__getattr__, "cache_info") + + # Enable envs cache and ignore ongoing environment changes + enable_envs_cache() + + # __getattr__ is now decorated with functools.cache + assert hasattr(envs.__getattr__, "cache_info") + start_hits = envs.__getattr__.cache_info().hits + + # 2 more hits due to VLLM_HOST_IP and VLLM_PORT accesses + assert envs.VLLM_HOST_IP == "1.1.1.1" + assert envs.VLLM_PORT == 1234 + assert envs.__getattr__.cache_info().hits == start_hits + 2 + + # All environment variables are cached + for environment_variable in environment_variables: + envs.__getattr__(environment_variable) + assert envs.__getattr__.cache_info().hits == start_hits + 2 + len( + environment_variables + ) + + # Reset envs.__getattr__ back to the non-cached version to + # avoid affecting other tests + envs.__getattr__ = envs.__getattr__.__wrapped__ class TestEnvWithChoices: diff 
--git a/tests/test_inputs.py b/tests/test_inputs.py index 77379cc8de90..66c4ff4135e2 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -7,6 +7,7 @@ from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.preprocess import InputPreprocessor +from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs pytestmark = pytest.mark.cpu_test @@ -106,7 +107,8 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): ) def test_preprocessor_text_no_mm_inputs(model_id, prompt): model_config = ModelConfig(model=model_id) - input_preprocessor = InputPreprocessor(model_config) + tokenizer = init_tokenizer_from_configs(model_config) + input_preprocessor = InputPreprocessor(model_config, tokenizer) with pytest.raises(ValueError, match="does not support multimodal inputs"): input_preprocessor.preprocess(prompt) @@ -127,11 +129,48 @@ def test_preprocessor_text_no_mm_inputs(model_id, prompt): ) def test_preprocessor_always_mm_code_path(model_id, prompt): model_config = ModelConfig(model=model_id) - input_preprocessor = InputPreprocessor(model_config) - tokenizer = input_preprocessor.tokenizer + tokenizer = init_tokenizer_from_configs(model_config) + input_preprocessor = InputPreprocessor(model_config, tokenizer) # HF processor adds sep token sep_token_id = tokenizer.vocab[tokenizer.sep_token] processed_inputs = input_preprocessor.preprocess(prompt) assert sep_token_id in processed_inputs["prompt_token_ids"] + + +def _get_bos_prefixed_prompt_or_skip(tokenizer): + bos_token = getattr(tokenizer, "bos_token", None) + if not bos_token or not isinstance(bos_token, str): + pytest.skip("Tokenizer has no string bos_token to test BOS handling.") + return f"{bos_token} Hello world" + + +@pytest.mark.parametrize( + "explicit_add_special", + [True, None], +) +def test_double_bos_token(monkeypatch, explicit_add_special): + model_config = ModelConfig(model="facebook/opt-125m") + input_preprocessor = InputPreprocessor(model_config) + + tokenizer = input_preprocessor.get_tokenizer() + prompt = _get_bos_prefixed_prompt_or_skip(tokenizer) + + captured: dict[str, object] = {} + + def fake_encode(text, **kwargs): + captured["kwargs"] = dict(kwargs) + # dummy + return [101, 102, 103] + + monkeypatch.setattr(tokenizer, "encode", fake_encode, raising=True) + + if explicit_add_special is True: + _ = input_preprocessor._tokenize_prompt( + prompt, tokenization_kwargs={"add_special_tokens": True} + ) + assert captured["kwargs"].get("add_special_tokens") is True + else: + _ = input_preprocessor._tokenize_prompt(prompt) + assert captured["kwargs"].get("add_special_tokens") is False diff --git a/tests/test_logger.py b/tests/test_logger.py index ec368d4897b5..01672358902f 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -501,3 +501,49 @@ def test_streaming_complete_logs_full_text_content(): assert call_args[1] == "test-streaming-full-text" assert call_args[2] == " (streaming complete)" assert call_args[5] == "streaming_complete" + + +# Add vllm prefix to make sure logs go through the vllm logger +test_logger = init_logger("vllm.test_logger") + + +def mp_function(**kwargs): + # This function runs in a subprocess + + test_logger.warning("This is a subprocess: %s", kwargs.get("a")) + test_logger.error("This is a subprocess error.") + test_logger.debug("This is a subprocess debug message: %s.", kwargs.get("b")) + + +def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork): + with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork(): + 
import multiprocessing + + ctx = multiprocessing.get_context("fork") + p = ctx.Process( + target=mp_function, + name=f"SubProcess{1}", + kwargs={"a": "AAAA", "b": "BBBBB"}, + ) + p.start() + p.join() + + assert "AAAA" in caplog_vllm.text + assert "BBBBB" in caplog_vllm.text + + +def test_caplog_mp_spawn(caplog_mp_spawn): + with caplog_mp_spawn(logging.DEBUG) as log_holder: + import multiprocessing + + ctx = multiprocessing.get_context("spawn") + p = ctx.Process( + target=mp_function, + name=f"SubProcess{1}", + kwargs={"a": "AAAA", "b": "BBBBB"}, + ) + p.start() + p.join() + + assert "AAAA" in log_holder.text + assert "BBBBB" in log_holder.text diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py index e3561ac3a577..e73d7efc1483 100644 --- a/tests/test_pooling_params.py +++ b/tests/test_pooling_params.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass + import pytest from tests.models.utils import EmbedModelInfo from vllm import PoolingParams -from vllm.config import ModelConfig +from vllm.config import ModelConfig, PoolerConfig EMBEDDING_MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), @@ -15,6 +17,15 @@ ), ] +classify_parameters = ["activation"] +embed_parameters = ["dimensions", "normalize"] +step_pooling_parameters = ["step_tag_id", "returned_token_ids"] + + +@dataclass() +class MockModelConfig: + pooler_config: PoolerConfig + def test_task(): pooling_params = PoolingParams() @@ -24,25 +35,27 @@ def test_task(): pooling_params.verify(task="score") with pytest.raises(ValueError): - pooling_params.verify(task="encode") + pooling_params.verify(task="classify") def test_embed(): task = "embed" + model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) + pooling_params = PoolingParams(normalize=None) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) pooling_params = PoolingParams(normalize=True) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) pooling_params = PoolingParams(normalize=False) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) - invalid_parameters = ["activation", "softmax"] + invalid_parameters = classify_parameters + step_pooling_parameters for p in invalid_parameters: with pytest.raises(ValueError): pooling_params = PoolingParams(**{p: True}) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @@ -73,35 +86,71 @@ def test_embed_dimensions(model_info: EmbedModelInfo): @pytest.mark.parametrize("task", ["score", "classify"]) def test_classify(task): + model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS")) + pooling_params = PoolingParams(activation=None) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) pooling_params = PoolingParams(activation=True) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) pooling_params = PoolingParams(activation=False) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) + + invalid_parameters = embed_parameters + step_pooling_parameters + for p in invalid_parameters: + with pytest.raises(ValueError): + pooling_params = PoolingParams(**{p: True}) + 
pooling_params.verify(task=task, model_config=model_config) + + +@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"]) +def test_token_embed(pooling_type: str): + task = "token_embed" + model_config = MockModelConfig( + pooler_config=PoolerConfig(pooling_type=pooling_type) + ) + + pooling_params = PoolingParams(normalize=None) + pooling_params.verify(task=task, model_config=model_config) + + pooling_params = PoolingParams(normalize=True) + pooling_params.verify(task=task, model_config=model_config) + + pooling_params = PoolingParams(normalize=False) + pooling_params.verify(task=task, model_config=model_config) + + invalid_parameters = classify_parameters + if pooling_type != "STEP": + invalid_parameters = classify_parameters + step_pooling_parameters - invalid_parameters = ["dimensions", "normalize", "softmax"] for p in invalid_parameters: with pytest.raises(ValueError): pooling_params = PoolingParams(**{p: True}) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) -def test_encode(): - task = "encode" - pooling_params = PoolingParams(softmax=None) - pooling_params.verify(task=task) +@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"]) +def test_token_classify(pooling_type: str): + task = "token_classify" + model_config = MockModelConfig( + pooler_config=PoolerConfig(pooling_type=pooling_type) + ) - pooling_params = PoolingParams(softmax=True) - pooling_params.verify(task=task) + pooling_params = PoolingParams(activation=None) + pooling_params.verify(task=task, model_config=model_config) + + pooling_params = PoolingParams(activation=True) + pooling_params.verify(task=task, model_config=model_config) + + pooling_params = PoolingParams(activation=False) + pooling_params.verify(task=task, model_config=model_config) - pooling_params = PoolingParams(softmax=False) - pooling_params.verify(task=task) + invalid_parameters = embed_parameters + if pooling_type != "STEP": + invalid_parameters = embed_parameters + step_pooling_parameters - invalid_parameters = ["dimensions", "normalize", "activation"] for p in invalid_parameters: with pytest.raises(ValueError): pooling_params = PoolingParams(**{p: True}) - pooling_params.verify(task=task) + pooling_params.verify(task=task, model_config=model_config) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 14dcab7707d4..f4b43a21daaa 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Generator -from typing import Any, Optional +from typing import Any import pytest from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -52,7 +52,7 @@ def _run_incremental_decode( skip_special_tokens: bool, starting_index: int, spaces_between_special_tokens: bool = True, - fast: Optional[bool] = None, + fast: bool | None = None, ): prompt_token_ids = all_input_ids[:starting_index] diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py index a034188387d0..ebf107217c3c 100644 --- a/tests/tokenization/test_mistral_tokenizer.py +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -1,27 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import pytest -from mistral_common.protocol.instruct.messages import ( - AssistantMessage, - ToolMessage, - UserMessage, -) 
-from mistral_common.protocol.instruct.request import ChatCompletionRequest -from mistral_common.protocol.instruct.tool_calls import ( - Function, - FunctionCall, - Tool, - ToolCall, -) +from mistral_common.exceptions import InvalidMessageStructureException +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from vllm.transformers_utils.tokenizers.mistral import ( - make_mistral_chat_completion_request, + MistralTokenizer, + _prepare_apply_chat_template_tools_and_messages, ) @pytest.mark.parametrize( - "openai_request,expected_mistral_request", + "openai_request,expected_mistral_output", [ ( { @@ -41,19 +34,22 @@ } ], }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What is the current local date and time?") + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } ], - tools=[ - Tool( - type="function", - function=Function( - name="get_current_time", - description="Fetch the current local date and time.", - parameters={}, - ), - ) + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + } ], ), ), @@ -71,39 +67,44 @@ "function": { "description": "Fetch the current local date and time.", "name": "get_current_time", - "parameters": None, + "parameters": {}, }, } ], }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What is the current local date and time?") + ( + [ + { + "role": "user", + "content": "What is the current local date and time?", + } ], - tools=[ - Tool( - type="function", - function=Function( - name="get_current_time", - description="Fetch the current local date and time.", - parameters={}, - ), - ) + [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": {}, + }, + } ], ), ), ], ) -def test_make_mistral_chat_completion_request(openai_request, expected_mistral_request): - actual_request = make_mistral_chat_completion_request( +def test_prepare_apply_chat_template_tools_and_messages( + openai_request, expected_mistral_output +): + actual_request = _prepare_apply_chat_template_tools_and_messages( openai_request["messages"], openai_request["tools"] ) - assert actual_request == expected_mistral_request + assert actual_request == expected_mistral_output # Tool use with list content and reasoning_content @pytest.mark.parametrize( - "openai_request,expected_mistral_request", + "openai_request,expected_mistral_output", [ ( { @@ -154,34 +155,40 @@ def test_make_mistral_chat_completion_request(openai_request, expected_mistral_r } ], }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What's the weather in Paris?"), - AssistantMessage( - content=None, - tool_calls=[ - ToolCall( - id="call123", - function=FunctionCall( - name="get_weather", - arguments='{"city": "Paris"}', - ), - ) + ( + [ + { + "role": "user", + "content": "What's the weather in Paris?", + }, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + } ], - ), - ToolMessage( - content="Rainy", - tool_call_id="call123", - name="get_weather", - ), + }, + { + "role": "tool", + "content": [{"type": "text", "text": "Rainy"}], + "name": "get_weather", + "tool_call_id": "call123", + }, ], - tools=[ - Tool( - type="function", - function=Function( - name="get_weather", - description="Gets the current weather in a 
city.", - parameters={ + [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { "type": "object", "properties": { "city": { @@ -191,17 +198,2012 @@ def test_make_mistral_chat_completion_request(openai_request, expected_mistral_r }, "required": ["city"], }, - ), - ) + }, + } ], ), ) ], ) -def test_make_mistral_chat_completion_request_list_content( - openai_request, expected_mistral_request +def test_prepare_apply_chat_template_tools_and_messages_list_content( + openai_request, expected_mistral_output ): - actual_request = make_mistral_chat_completion_request( + actual_request = _prepare_apply_chat_template_tools_and_messages( openai_request["messages"], openai_request["tools"] ) - assert actual_request == expected_mistral_request + assert actual_request == expected_mistral_output + + +def test_prepare_apply_chat_template_generation_prompt_and_continue(): + messages = [{"role": "assistant", "content": "Hello"}] + tools: list[dict[str, Any]] = [] + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True + ) + + messages = [{"role": "user", "content": "Hello"}] + out_messages, _ = _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True + ) + assert out_messages == [{"role": "user", "content": "Hello"}] + + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=True, continue_final_message=True + ) + + messages = [{"role": "assistant", "content": "Hello"}] + out_messages, _ = _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=False, continue_final_message=True + ) + assert out_messages == [{"role": "assistant", "content": "Hello"}] + + messages = [{"role": "user", "content": "Hello"}] + with pytest.raises(ValueError): + _prepare_apply_chat_template_tools_and_messages( + messages, tools, add_generation_prompt=False, continue_final_message=True + ) + + +@pytest.fixture(scope="module") +def mistral_tokenizer(request) -> MistralTokenizer: + return MistralTokenizer.from_pretrained(request.param) + + +@pytest.mark.parametrize( + "mistral_tokenizer", + ["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Magistral-Small-2509"], + indirect=True, +) +class TestMistralTokenizer: + def test_all_special_tokens(self, mistral_tokenizer: MistralTokenizer): + attributes = [ + mistral_tokenizer.all_special_tokens, + mistral_tokenizer.all_special_tokens_extended, + ] + + for attribute in attributes: + if mistral_tokenizer.is_tekken: + assert attribute == [ + "<unk>", + "<s>", + "</s>", + "[INST]", + "[/INST]", + "[AVAILABLE_TOOLS]", + "[/AVAILABLE_TOOLS]", + "[TOOL_RESULTS]", + "[/TOOL_RESULTS]", + "[TOOL_CALLS]", + "[IMG]", + "<pad>", + "[IMG_BREAK]", + "[IMG_END]", + "[PREFIX]", + "[MIDDLE]", + "[SUFFIX]", + "[SYSTEM_PROMPT]", + "[/SYSTEM_PROMPT]", + "[TOOL_CONTENT]", + ] + [f"<SPECIAL_{i}>" for i in range(20, 32)] + [ + "[ARGS]", + "[CALL_ID]", + "[THINK]", + "[/THINK]", + ] + [f"<SPECIAL_{i}>" for i in range(36, 1000)] + else: + assert attribute == [ + "<s>", + "</s>", + "[INST]", + "[/INST]", + "[TOOL_CALLS]", + "[AVAILABLE_TOOLS]", + "[/AVAILABLE_TOOLS]", + "[TOOL_RESULTS]", + "[/TOOL_RESULTS]", + ] + [f"[control_{i}]" for i in range(8, 769)] + + def get_vocab(self, mistral_tokenizer: MistralTokenizer): + assert ( + mistral_tokenizer.get_vocab() + == mistral_tokenizer.transformers_tokenizer.get_vocab() + ) 
+ + def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer): + assert mistral_tokenizer.get_added_vocab() == {} + + def test_encode_one(self, mistral_tokenizer: MistralTokenizer): + token_ids = ( + [22177, 4304, 2662] if mistral_tokenizer.is_tekken else [23325, 2294, 1686] + ) + + assert mistral_tokenizer.encode_one("Hello world !") == token_ids + assert mistral_tokenizer.encode_one("Hello world !", max_length=1) == token_ids + assert ( + mistral_tokenizer.encode_one("Hello world !", truncation=True, max_length=1) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode_one( + "Hello world !", truncation=False, max_length=1 + ) + == token_ids + ) + + def test_encode(self, mistral_tokenizer: MistralTokenizer): + token_ids = ( + [1, 22177, 4304, 2662, 2] + if mistral_tokenizer.is_tekken + else [1, 23325, 2294, 1686, 2] + ) + + assert mistral_tokenizer.encode("Hello world !") == token_ids[:-1] + assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-2] + assert ( + mistral_tokenizer.encode("Hello world !", truncation=True, max_length=3) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode("Hello world !", truncation=False, max_length=3) + == token_ids[:-1] + ) + + assert ( + mistral_tokenizer.encode("Hello world !", add_special_tokens=True) + == token_ids + ) + assert ( + mistral_tokenizer.encode( + "Hello world !", add_special_tokens=True, max_length=3 + ) + == token_ids[:-2] + ) + assert ( + mistral_tokenizer.encode( + "Hello world !", add_special_tokens=True, truncation=False, max_length=3 + ) + == token_ids + ) + assert ( + mistral_tokenizer.encode("Hello world !", add_special_tokens=False) + == token_ids[1:-1] + ) + + @pytest.mark.parametrize( + "openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output", + [ + ( + { + "messages": [ + { + "role": "user", + "content": "Hello world !", + } + ], + }, + True, + False, + ([1, 3, 23325, 2294, 1686, 4], [1, 3, 22177, 4304, 2662, 4]), + ("<s>[INST]▁Hello▁world▁![/INST]", ("<s>[INST]Hello world ![/INST]")), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + ], + }, + True, + False, + ( + [1, 3, 1083, 1605, 1164, 16875, 781, 781, 16998, 2294, 1686, 4], + [1, 17, 1073, 1855, 1420, 26554, 18, 3, 22177, 4304, 2662, 4], + ), + ( + "<s>[INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]", + ( + "<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][INST]Hello world ![/INST]" # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } + ], + }, + True, + False, + ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, 
+ 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + ], + ), + ( + '<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]', + ( + '<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST]' # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "system", + "content": "I am an AI", + }, + { + "role": "user", + "content": "Hello world !", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "123456789", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "123456789", + "content": '{"temperature": 20, "unit": "celsius"}', + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } + ], + }, + True, + False, + ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, + 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + 5, + 1501, + 7567, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 17452, + 2032, + 10598, + 19141, + 2032, + 1113, + 4684, + 1046, + 8474, + 1113, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 29507, + 10925, + 2, + 8, + 10598, + 4557, + 2032, + 10598, + 29475, + 17329, + 2032, + 29473, + 29518, + 29502, + 29493, + 1113, + 6074, + 2032, + 1113, + 29485, + 1958, + 3938, + 8474, + 1113, + 3613, + 29498, + 1081, + 2032, + 
1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 18163, + 9, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + 9, + 1689, + 1095, + 45629, + 32, + 19227, + 29363, + 2811, + 1429, + 42572, + 46005, + 2, + 7, + 19227, + 113824, + 2811, + 1032, + 1050, + 1048, + 1044, + 1429, + 8979, + 2811, + 1429, + 1099, + 79092, + 46005, + 8, + ], + ), + ( + '<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST][TOOL_CALLS]▁[{"name":▁"get_weather",▁"arguments":▁{"city":▁"Paris"},▁"id":▁"123456789"}]</s>[TOOL_RESULTS]▁{"content":▁{"temperature":▁20,▁"unit":▁"celsius"},▁"call_id":▁"123456789"}[/TOOL_RESULTS]', + ( + '<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST][TOOL_CALLS]get_weather[ARGS]{"city": "Paris"}</s>[TOOL_RESULTS]{"temperature": 20, "unit": "celsius"}[/TOOL_RESULTS]' # noqa: E501 + ), + ), + ), + ( + { + "messages": [ + { + "role": "user", + "content": "Hello world !", + }, + { + "role": "assistant", + "content": "Hello ", + }, + ], + }, + False, + True, + ( + [1, 3, 23325, 2294, 1686, 4, 23325], + [1, 3, 22177, 4304, 2662, 4, 22177, 2], + ), + ( + "<s>[INST]▁Hello▁world▁![/INST]▁Hello", + ("<s>[INST]Hello world ![/INST]Hello</s>"), + ), + ), + ], + ) + def test_apply_chat_template( + self, + mistral_tokenizer: MistralTokenizer, + openai_request: dict[str, Any], + add_generation_prompt: bool, + continue_final_message: bool, + expected_output: tuple[list[int], list[int]], + decoded_expected_output: tuple[str, str], + ): + actual_output = mistral_tokenizer.apply_chat_template( + openai_request["messages"], + tools=openai_request.get("tools", []), + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + ) + decoded_actual_output = mistral_tokenizer.tokenizer.decode( + actual_output, SpecialTokenPolicy.KEEP + ) + + assert actual_output == expected_output[mistral_tokenizer.is_tekken] + assert ( + decoded_actual_output + == decoded_expected_output[mistral_tokenizer.is_tekken] + ) + + def test_apply_chat_template_error(self, mistral_tokenizer: MistralTokenizer): + messages = [{"role": "user", "content": "Hello world !"}] + + with pytest.raises(ValueError): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=True, + continue_final_message=True, + ) + + with pytest.raises(ValueError): + 
mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=False, + continue_final_message=True, + ) + + messages = [ + {"role": "user", "content": "Hello world !"}, + {"role": "assistant", "content": "Hello "}, + ] + with pytest.raises(ValueError): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=True, + continue_final_message=False, + ) + + messages = [ + {"role": "user", "content": "Hello world !"}, + {"role": "assistant", "content": "Hello "}, + ] + with pytest.raises(InvalidMessageStructureException): + mistral_tokenizer.apply_chat_template( + messages, + tools=[], + add_generation_prompt=False, + continue_final_message=False, + ) + + @pytest.mark.parametrize( + "skip_special_tokens,expected_tokens", + ( + ( + False, + ( + "<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>", + "<s>[INST]Hello world ![/INST]Hello</s>", + ), + ), + (True, ("Hello world ! Hello", "Hello world !Hello")), + ), + ) + def test_decode( + self, + mistral_tokenizer: MistralTokenizer, + skip_special_tokens: bool, + expected_tokens: tuple[str, str], + ): + ids = ( + [1, 3, 23325, 2294, 1686, 4, 23325, 2], + [1, 3, 22177, 4304, 2662, 4, 22177, 2], + ) + assert ( + mistral_tokenizer.decode( + ids[mistral_tokenizer.is_tekken], + skip_special_tokens=skip_special_tokens, + ) + == expected_tokens[mistral_tokenizer.is_tekken] + ) + + def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer): + tokens = ( + [ + "<s>", + "[AVAILABLE_TOOLS]", + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + '":', + '▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[/INST]", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + "</s>", + "[TOOL_RESULTS]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + "[/TOOL_RESULTS]", + ], + [ + "<s>", + "[SYSTEM_PROMPT]", + "I", + " am", + " an", + " AI", + "[/SYSTEM_PROMPT]", + "[AVAILABLE_TOOLS]", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', 
+ "string", + '",', + ' "', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "Hello", + " world", + " !", + "[/INST]", + "[TOOL_CALLS]", + "get", + "_", + "weather", + "[ARGS]", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + "</s>", + "[TOOL_RESULTS]", + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + "[/TOOL_RESULTS]", + ], + ) + + expected_strings = ( + '[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}] I am an AI\n\nHello world ![TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "123456789"}] {"content": {"temperature": 20, "unit": "celsius"}, "call_id": "123456789"}', # noqa: E501 + 'I am an AI[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}]Hello world ![TOOL_CALLS]get_weather{"city": "Paris"}{"temperature": 20, "unit": "celsius"}', # noqa: E501 + ) + + assert ( + mistral_tokenizer.convert_tokens_to_string( + tokens[mistral_tokenizer.is_tekken] + ) + == expected_strings[mistral_tokenizer.is_tekken] + ) + + @pytest.mark.parametrize( + "skip_special_tokens,tuple_expected_tokens", + ( + ( + True, + ( + [ + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + '":', + '▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + ], + [ + "I", + " am", + " an", + " AI", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', + "string", + '",', + ' 
"', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "Hello", + " world", + " !", + "[TOOL_CALLS]", + "get", + "_", + "weather", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + ], + ), + ), + ( + False, + ( + [ + "<s>", + "[AVAILABLE_TOOLS]", + "▁[", + '{"', + "type", + '":', + '▁"', + "function", + '",', + '▁"', + "function", + '":', + '▁{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "description", + '":', + '▁"', + "Get", + "s", + "▁the", + "▁current", + "▁weather", + "▁in", + "▁a", + "▁city", + '.",', + '▁"', + "parameters", + '":', + '▁{"', + "type", + '":', + '▁"', + "object", + '",', + '▁"', + "properties", + '":', + '▁{"', + "city", + '":', + '▁{"', + "type", + '":', + '▁"', + "string", + '",', + '▁"', + "description", + '":', + '▁"', + "The", + "▁city", + "▁name", + '"', + "}},", + '▁"', + "required", + '":', + '▁["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "▁I", + "▁am", + "▁an", + "▁AI", + "<0x0A>", + "<0x0A>", + "Hello", + "▁world", + "▁!", + "[/INST]", + "[TOOL_CALLS]", + "▁[", + '{"', + "name", + '":', + '▁"', + "get", + "_", + "we", + "ather", + '",', + '▁"', + "arguments", + '":', + '▁{"', + "city", + '":', + '▁"', + "Par", + "is", + '"},', + '▁"', + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"', + "}]", + "</s>", + "[TOOL_RESULTS]", + '▁{"', + "content", + '":', + '▁{"', + "t", + "emperature", + '":', + "▁", + "2", + "0", + ",", + '▁"', + "unit", + '":', + '▁"', + "c", + "els", + "ius", + '"},', + '▁"', + "call", + "_", + "id", + '":', + '▁"', + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + '"}', + "[/TOOL_RESULTS]", + ], + [ + "<s>", + "[SYSTEM_PROMPT]", + "I", + " am", + " an", + " AI", + "[/SYSTEM_PROMPT]", + "[AVAILABLE_TOOLS]", + "[", + '{"', + "type", + '":', + ' "', + "function", + '",', + ' "', + "function", + '":', + ' {"', + "name", + '":', + ' "', + "get", + "_", + "weather", + '",', + ' "', + "description", + '":', + ' "', + "G", + "ets", + " the", + " current", + " weather", + " in", + " a", + " city", + '.",', + ' "', + "parameters", + '":', + ' {"', + "type", + '":', + ' "', + "object", + '",', + ' "', + "properties", + '":', + ' {"', + "city", + '":', + ' {"', + "type", + '":', + ' "', + "string", + '",', + ' "', + "description", + '":', + ' "', + "The", + " city", + " name", + '"', + "}},", + ' "', + "required", + '":', + ' ["', + "city", + '"]', + "}}", + "}]", + "[/AVAILABLE_TOOLS]", + "[INST]", + "Hello", + " world", + " !", + "[/INST]", + "[TOOL_CALLS]", + "get", + "_", + "weather", + "[ARGS]", + '{"', + "city", + '":', + ' "', + "Paris", + '"}', + "</s>", + "[TOOL_RESULTS]", + '{"', + "temperature", + '":', + " ", + "2", + "0", + ",", + ' "', + "unit", + '":', + ' "', + "c", + "elsius", + '"}', + "[/TOOL_RESULTS]", + ], + ), + ), + ), + ) + def test_convert_ids_to_tokens( + self, + mistral_tokenizer: MistralTokenizer, + skip_special_tokens: bool, + tuple_expected_tokens: tuple[list[str], list[str]], + ): + tuple_ids = ( + [ + 1, + 6, + 1501, + 7567, + 1891, + 2032, + 1113, + 3396, + 1316, + 1113, + 3396, + 2032, + 10598, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 7286, + 2032, + 1113, + 2226, + 29481, + 1040, + 2636, + 8854, + 1065, + 1032, + 3758, + 9959, + 
1113, + 12206, + 2032, + 10598, + 1891, + 2032, + 1113, + 3582, + 1316, + 1113, + 11491, + 2032, + 10598, + 19141, + 2032, + 10598, + 1891, + 2032, + 1113, + 2195, + 1316, + 1113, + 7286, + 2032, + 1113, + 1782, + 3758, + 1909, + 29507, + 11549, + 1113, + 11661, + 2032, + 8135, + 19141, + 3010, + 1743, + 10925, + 7, + 3, + 1083, + 1605, + 1164, + 16875, + 781, + 781, + 16998, + 2294, + 1686, + 4, + 5, + 1501, + 7567, + 1629, + 2032, + 1113, + 1295, + 29498, + 1537, + 1991, + 1316, + 1113, + 17452, + 2032, + 10598, + 19141, + 2032, + 1113, + 4684, + 1046, + 8474, + 1113, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 29507, + 10925, + 2, + 8, + 10598, + 4557, + 2032, + 10598, + 29475, + 17329, + 2032, + 29473, + 29518, + 29502, + 29493, + 1113, + 6074, + 2032, + 1113, + 29485, + 1958, + 3938, + 8474, + 1113, + 3613, + 29498, + 1081, + 2032, + 1113, + 29508, + 29518, + 29538, + 29549, + 29550, + 29552, + 29555, + 29551, + 29542, + 18163, + 9, + ], + [ + 1, + 17, + 1073, + 1855, + 1420, + 26554, + 18, + 5, + 1091, + 19227, + 4994, + 2811, + 1429, + 5165, + 1897, + 1429, + 5165, + 2811, + 16753, + 2391, + 2811, + 1429, + 1689, + 1095, + 45629, + 1897, + 1429, + 14653, + 2811, + 1429, + 1071, + 3083, + 1278, + 3519, + 17253, + 1294, + 1261, + 5970, + 39249, + 1429, + 26204, + 2811, + 16753, + 4994, + 2811, + 1429, + 6371, + 1897, + 1429, + 48649, + 2811, + 16753, + 29363, + 2811, + 16753, + 4994, + 2811, + 1429, + 3607, + 1897, + 1429, + 14653, + 2811, + 1429, + 1784, + 5970, + 2564, + 1034, + 47579, + 1429, + 15760, + 2811, + 12161, + 29363, + 4964, + 2821, + 27028, + 6, + 3, + 22177, + 4304, + 2662, + 4, + 9, + 1689, + 1095, + 45629, + 32, + 19227, + 29363, + 2811, + 1429, + 42572, + 46005, + 2, + 7, + 19227, + 113824, + 2811, + 1032, + 1050, + 1048, + 1044, + 1429, + 8979, + 2811, + 1429, + 1099, + 79092, + 46005, + 8, + ], + ) + + ids = tuple_ids[mistral_tokenizer.is_tekken] + expected_tokens = tuple_expected_tokens[mistral_tokenizer.is_tekken] + actual_tokens = mistral_tokenizer.convert_ids_to_tokens( + ids, skip_special_tokens=skip_special_tokens + ) + assert actual_tokens == expected_tokens diff --git a/tests/tokenization/test_tokenizer_registry.py b/tests/tokenization/test_tokenizer_registry.py index de67c3e798c4..d89737888aa2 100644 --- a/tests/tokenization/test_tokenizer_registry.py +++ b/tests/tokenization/test_tokenizer_registry.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry @@ -61,11 +61,11 @@ def truncation_side(self) -> str: def __call__( self, - text: Union[str, list[str], list[int]], - text_pair: Optional[str] = None, + text: str | list[str] | list[int], + text_pair: str | None = None, add_special_tokens: bool = False, truncation: bool = False, - max_length: Optional[int] = None, + max_length: int | None = None, ): raise NotImplementedError() @@ -79,17 +79,17 @@ def encode_one( self, text: str, truncation: bool = False, - max_length: Optional[int] = None, + max_length: int | None = None, ) -> list[int]: raise NotImplementedError() - def encode(self, text: str, add_special_tokens: Optional[bool] = None) -> list[int]: + def encode(self, text: str, add_special_tokens: bool | None = None) -> list[int]: raise 
NotImplementedError() def apply_chat_template( self, messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, + tools: list[dict[str, Any]] | None = None, **kwargs, ) -> list[int]: raise NotImplementedError() @@ -97,9 +97,7 @@ def apply_chat_template( def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError() - def decode( - self, ids: Union[list[int], int], skip_special_tokens: bool = True - ) -> str: + def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str: raise NotImplementedError() def convert_ids_to_tokens( diff --git a/tests/tool_use/mistral/utils.py b/tests/tool_use/mistral/utils.py index 13a234f8e26b..4d772ba63793 100644 --- a/tests/tool_use/mistral/utils.py +++ b/tests/tool_use/mistral/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from typing_extensions import TypedDict @@ -9,9 +8,9 @@ class ServerConfig(TypedDict, total=False): model: str arguments: list[str] - system_prompt: Optional[str] - supports_parallel: Optional[bool] - supports_rocm: Optional[bool] + system_prompt: str | None + supports_parallel: bool | None + supports_rocm: bool | None ARGS: list[str] = ["--max-model-len", "1024"] diff --git a/tests/tool_use/test_ernie45_moe_tool_parser.py b/tests/tool_use/test_ernie45_moe_tool_parser.py new file mode 100644 index 000000000000..0862d14812d7 --- /dev/null +++ b/tests/tool_use/test_ernie45_moe_tool_parser.py @@ -0,0 +1,359 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json +from collections.abc import Generator + +import pytest + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + FunctionCall, + ToolCall, +) +from vllm.entrypoints.openai.tool_parsers import Ernie45ToolParser +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +# Use a common model that is likely to be available +MODEL = "baidu/ERNIE-4.5-21B-A3B-Thinking" + + +@pytest.fixture(scope="module") +def ernie45_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True) + + +@pytest.fixture +def ernie45_tool_parser(ernie45_tokenizer): + return Ernie45ToolParser(ernie45_tokenizer) + + +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) > 0 + + assert actual_tool_call.type == "function" + assert actual_tool_call.function.name == expected_tool_call.function.name + # Compare arguments as JSON objects to handle formatting differences + actual_args = json.loads(actual_tool_call.function.arguments) + expected_args = json.loads(expected_tool_call.function.arguments) + assert actual_args == expected_args + + +def test_extract_tool_calls_no_tools(ernie45_tool_parser): + model_output = "This is a test" + extracted_tool_calls = ernie45_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + 
+@pytest.mark.parametrize( + ids=[ + "single_tool_call", + "multiple_tool_calls", + "tool_call_with_content_before", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """<tool_call> +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} +</tool_call> +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ) + ], + None, + ), + ( + """<tool_call> +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} +</tool_call> +<tool_call> +{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}} +</tool_call> +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_temperature_unit", + arguments=json.dumps( + { + "location": "Guangzhou", + "unit": "c", + } + ), + ) + ), + ], + None, + ), + ( + """I need to call two tools to handle these two issues separately. +</think> + +<tool_call> +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} +</tool_call> +<tool_call> +{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}} +</tool_call> +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_temperature_unit", + arguments=json.dumps( + { + "location": "Guangzhou", + "unit": "c", + } + ), + ) + ), + ], + "I need to call two tools to handle these two issues separately.\n</think>", + ), + ], +) +def test_extract_tool_calls( + ernie45_tool_parser, model_output, expected_tool_calls, expected_content +): + extracted_tool_calls = ernie45_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def stream_delta_message_generator( + ernie45_tool_parser: Ernie45ToolParser, + ernie45_tokenizer: AnyTokenizer, + model_output: str, + request: ChatCompletionRequest | None = None, +) -> Generator[DeltaMessage, None, None]: + all_token_ids = ernie45_tokenizer.encode(model_output, add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=ernie45_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) + + current_text = previous_text + delta_text + + delta_message = ernie45_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + 
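The generator above replays the tokenized output one token at a time through `extract_tool_calls_streaming`, yielding `DeltaMessage` chunks. A compact sketch of consuming it and re-assembling arguments per tool-call index, assuming the same fixtures as above (the helper name is illustrative; the full accumulation logic lives in the streaming test below):

def _sketch_stream_and_collect(ernie45_tool_parser, ernie45_tokenizer):
    # Illustrative helper: concatenate streamed argument fragments keyed by
    # tool-call index, mirroring what the streaming test below does in full.
    model_output = (
        "<tool_call>\n"
        '{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}\n'
        "</tool_call>\n"
    )
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
    args_by_index: dict[int, str] = {}
    for delta in stream_delta_message_generator(
        ernie45_tool_parser, ernie45_tokenizer, model_output, request
    ):
        for chunk in delta.tool_calls or []:
            if chunk.function and chunk.function.arguments:
                args_by_index[chunk.index] = (
                    args_by_index.get(chunk.index, "") + chunk.function.arguments
                )
    assert json.loads(args_by_index[0]) == {"location": "Beijing"}
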
+@pytest.mark.parametrize( + ids=[ + "single_tool_call", + "multiple_tool_calls", + "tool_call_with_content_before", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """<tool_call> +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} +</tool_call> +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ) + ], + None, + ), + ( + """<tool_call> +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} +</tool_call> +<tool_call> +{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}} +</tool_call> +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_temperature_unit", + arguments=json.dumps( + { + "location": "Guangzhou", + "unit": "c", + } + ), + ) + ), + ], + None, + ), + ( + """I need to call two tools to handle these two issues separately. +</think> + +<tool_call> +{"name": "get_current_temperature", "arguments": {"location": "Beijing"}} +</tool_call> +<tool_call> +{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}} +</tool_call> +""", + [ + ToolCall( + function=FunctionCall( + name="get_current_temperature", + arguments=json.dumps( + { + "location": "Beijing", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_temperature_unit", + arguments=json.dumps( + { + "location": "Guangzhou", + "unit": "c", + } + ), + ) + ), + ], + "I need to call two tools to handle these two issues separately.\n</think>", + ), + ], +) +def test_extract_tool_calls_streaming_incremental( + ernie45_tool_parser, + ernie45_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + """Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501 + request = ChatCompletionRequest(model=MODEL, messages=[], tools=[]) + + tool_calls_dict = {} + for delta_message in stream_delta_message_generator( + ernie45_tool_parser, ernie45_tokenizer, model_output, request + ): + if ( + delta_message.role is None + and delta_message.content is None + and delta_message.reasoning_content is None + and len(delta_message.tool_calls) == 0 + ): + continue + tool_calls = delta_message.tool_calls + for tool_call_chunk in tool_calls: + index = tool_call_chunk.index + if index not in tool_calls_dict: + if tool_call_chunk.function.arguments is None: + tool_call_chunk.function.arguments = "" + tool_calls_dict[index] = tool_call_chunk + else: + tool_calls_dict[ + index + ].function.arguments += tool_call_chunk.function.arguments + actual_tool_calls = list(tool_calls_dict.values()) + + assert len(actual_tool_calls) > 0 + # check tool call format + assert_tool_calls(actual_tool_calls, expected_tool_calls) diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py index 44d42bbd72b0..6dcdd5ba2ce7 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -3,7 +3,6 @@ import json from collections.abc import Generator -from typing import Optional import partial_json_parser import pytest @@ -248,7 +247,7 @@ def test_extract_tool_calls_streaming( function_names: list[str] = [] function_args_strs: list[str] = [] tool_call_idx: int = -1 - tool_call_ids: list[Optional[str]] = [] + tool_call_ids: list[str | None] = [] for 
delta_message in stream_delta_message_generator( jamba_tool_parser, jamba_tokenizer, model_output diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index 159966365ec4..9af94a6a64a2 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json -from typing import Optional import openai import pytest @@ -80,7 +79,7 @@ async def test_parallel_tool_calls( stream=True, ) - role_name: Optional[str] = None + role_name: str | None = None finish_reason_count: int = 0 tool_call_names: list[str] = [] diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index 20fa3b08c7b9..93ef1049fc07 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -3,7 +3,6 @@ import json from collections.abc import Generator -from typing import Optional import pytest @@ -41,7 +40,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer): return Qwen3XMLToolParser(qwen3_tokenizer) -@pytest.fixture(params=["original", "xml"]) +@pytest.fixture(params=["xml"]) def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request): """Parameterized fixture that provides both parser types for testing""" if request.param == "original": @@ -107,7 +106,7 @@ def stream_delta_message_generator( qwen3_tool_parser, qwen3_tokenizer: AnyTokenizer, model_output: str, - request: Optional[ChatCompletionRequest] = None, + request: ChatCompletionRequest | None = None, ) -> Generator[DeltaMessage, None, None]: all_token_ids = qwen3_tokenizer.encode(model_output, add_special_tokens=False) @@ -665,6 +664,9 @@ def test_extract_tool_calls_streaming( # Verify we got all expected tool calls assert len(tool_states) == len(expected_tool_calls) + assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == len( + expected_tool_calls + ) # Verify each tool call for idx, expected_tool in enumerate(expected_tool_calls): @@ -781,9 +783,10 @@ def test_extract_tool_calls_streaming_missing_closing_tag( # Verify content was streamed assert "Let me check the weather for you:" in other_content - # Verify we got the tool call assert len(tool_states) == 1 + assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1 + state = tool_states[0] assert state["id"] is not None assert state["type"] == "function" @@ -893,3 +896,83 @@ def test_extract_tool_calls_complex_type_with_single_quote( args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) assert args["obj_param"] == {"key": "value"} + + +def test_extract_tool_calls_streaming_missing_opening_tag( + qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools +): + """Test streaming with missing opening <tool_call> tag + + This tests that the streaming parser correctly handles + tool calls that start directly with <function=...> + """ + model_output = """I'll check the weather for you. 
+ +<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>""" + + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + + other_content = "" + tool_states = {} + + for delta_message in stream_delta_message_generator( + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request + ): + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None, + } + + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx]["arguments"] += tool_call.function.arguments + + # Verify content was streamed + assert "I'll check the weather for you." in other_content + + # Verify we got the tool call + assert len(tool_states) == 1 + assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1 + + state = tool_states[0] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == "get_current_weather" + + # Verify arguments were parsed correctly despite missing opening tag + assert state["arguments"] is not None + args = json.loads(state["arguments"]) + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py index eddb5a9b9f5e..1133b949f227 100644 --- a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -4,7 +4,6 @@ import json from collections.abc import Generator -from typing import Optional import pytest @@ -259,7 +258,7 @@ def stream_delta_message_generator( seed_oss_tool_parser: SeedOssToolParser, seed_oss_tokenizer: AnyTokenizer, model_output: str, - request: Optional[ChatCompletionRequest] = None, + request: ChatCompletionRequest | None = None, ) -> Generator[DeltaMessage, None, None]: all_token_ids = seed_oss_tokenizer.encode(model_output, add_special_tokens=False) diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index 64186aaac6a7..6614b6415a04 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json -from typing import Optional import openai import pytest @@ -58,10 +57,10 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert stop_reason == "tool_calls" - function_name: Optional[str] = None + function_name: str | None = None function_args_str: str = "" - tool_call_id: Optional[str] = None - role_name: Optional[str] = None + tool_call_id: str | None = None + role_name: str | None = None finish_reason_count: int = 0 # make the same request, streaming diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index d52c141f6210..d5572cfbebe3 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -9,10 +9,10 @@ from pydantic import TypeAdapter 
from vllm.entrypoints.openai.protocol import ( - ChatCompletionRequest, ChatCompletionToolsParam, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools pytestmark = pytest.mark.cpu_test @@ -67,8 +67,9 @@ def _compile_and_check( tools: list[ChatCompletionToolsParam], sample_output, should_match: bool ): - self = MagicMock(tool_choice="required", tools=tools) - schema = ChatCompletionRequest._get_json_schema_from_tool(self) + # self = MagicMock(tool_choice="required", tools=tools) + # schema = ChatCompletionRequest._get_json_schema_from_tool(self) + schema = get_json_schema_from_tools(tools=tools, tool_choice="required") assert isinstance(schema, dict) # use build_regex_from_schema used in JSONLogitsProcessor to create Guide diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index bdac878db4e7..8c27b2911f8f 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -3,7 +3,6 @@ import json from collections.abc import Generator -from typing import Optional import pytest @@ -52,7 +51,7 @@ def stream_delta_message_generator( xlam_tool_parser: xLAMToolParser, xlam_tokenizer: AnyTokenizer, model_output: str, - request: Optional[ChatCompletionRequest] = None, + request: ChatCompletionRequest | None = None, ) -> Generator[DeltaMessage, None, None]: all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False) diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 835d07608e40..38def6f874d7 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from typing import Any, Optional +from typing import Any from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam from typing_extensions import TypedDict @@ -13,10 +13,10 @@ class ServerConfig(TypedDict, total=False): model: str arguments: list[str] - system_prompt: Optional[str] - supports_parallel: Optional[bool] - supports_rocm: Optional[bool] - extended: Optional[bool] # tests do not run in CI automatically + system_prompt: str | None + supports_parallel: bool | None + supports_rocm: bool | None + extended: bool | None # tests do not run in CI automatically def patch_system_prompt( diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 102e5ddf16d6..cf455ff3edbd 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -3,7 +3,7 @@ import pytest -from vllm.config import CompilationLevel +from vllm.config import CompilationMode from ..utils import compare_two_settings @@ -21,13 +21,13 @@ def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch): "--max-model-len=256", "--max-num-seqs=32", "--enforce-eager", - f"-O{CompilationLevel.DYNAMO_ONCE}", + f"-O{CompilationMode.DYNAMO_TRACE_ONCE}", ], arg2=[ "--max-model-len=256", "--max-num-seqs=32", "--enforce-eager", - f"-O{CompilationLevel.DYNAMO_AS_IS}", + f"-O{CompilationMode.STOCK_TORCH_COMPILE}", ], env1={}, env2={}, diff --git a/tests/transformers_utils/test_config_parser_registry.py b/tests/transformers_utils/test_config_parser_registry.py index 9372cb9d46d3..0931bd734f8f 100644 --- a/tests/transformers_utils/test_config_parser_registry.py +++ b/tests/transformers_utils/test_config_parser_registry.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the 
vLLM project from pathlib import Path -from typing import Optional, Union import pytest from transformers import PretrainedConfig @@ -15,10 +14,10 @@ class CustomConfigParser(ConfigParserBase): def parse( self, - model: Union[str, Path], + model: str | Path, trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, + revision: str | None = None, + code_revision: str | None = None, **kwargs, ) -> tuple[dict, PretrainedConfig]: raise NotImplementedError diff --git a/tests/utils.py b/tests/utils.py index b853542c241f..fb7614dd7fbc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,6 +6,7 @@ import copy import functools import importlib +import itertools import json import os import random @@ -15,12 +16,14 @@ import tempfile import time import warnings +from collections.abc import Callable, Iterable from contextlib import ExitStack, contextmanager, suppress from multiprocessing import Process from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Literal from unittest.mock import patch +import anthropic import cloudpickle import httpx import openai @@ -44,10 +47,10 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import ( FlexibleArgumentParser, - GB_bytes, - cuda_device_count_stateless, - get_open_port, ) +from vllm.utils.mem_constants import GB_bytes +from vllm.utils.network_utils import get_open_port +from vllm.utils.torch_utils import cuda_device_count_stateless if current_platform.is_rocm(): from amdsmi import ( @@ -94,7 +97,7 @@ class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key def _start_server( - self, model: str, vllm_serve_args: list[str], env_dict: Optional[dict[str, str]] + self, model: str, vllm_serve_args: list[str], env_dict: dict[str, str] | None ) -> None: """Subclasses override this method to customize server process launch""" env = os.environ.copy() @@ -117,11 +120,11 @@ def __init__( model: str, vllm_serve_args: list[str], *, - env_dict: Optional[dict[str, str]] = None, - seed: Optional[int] = 0, + env_dict: dict[str, str] | None = None, + seed: int | None = 0, auto_port: bool = True, - max_wait_seconds: Optional[float] = None, - override_hf_configs: Optional[dict[str, Any]] = None, + max_wait_seconds: float | None = None, + override_hf_configs: dict[str, Any] | None = None, ) -> None: if auto_port: if "-p" in vllm_serve_args or "--port" in vllm_serve_args: @@ -156,7 +159,7 @@ def __init__( self.host = None self.port = None else: - self.host = str(args.host or "localhost") + self.host = str(args.host or "127.0.0.1") self.port = int(args.port) self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None @@ -186,7 +189,7 @@ def __exit__(self, exc_type, exc_value, traceback): # force kill if needed self.proc.kill() - def _poll(self) -> Optional[int]: + def _poll(self) -> int | None: """Subclasses override this method to customize process polling""" return self.proc.poll() @@ -251,7 +254,7 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer): """Launch test server with custom child process""" def _start_server( - self, model: str, vllm_serve_args: list[str], env_dict: Optional[dict[str, str]] + self, model: str, vllm_serve_args: list[str], env_dict: dict[str, str] | None ) -> None: self.proc: Process = Process( target=self.child_process_fxn, args=(env_dict, model, vllm_serve_args) @@ -262,12 +265,12 @@ def __init__( self, model: str, vllm_serve_args: list[str], - child_process_fxn: 
Callable[[Optional[dict[str, str]], str, list[str]], None], + child_process_fxn: Callable[[dict[str, str] | None, str, list[str]], None], *, - env_dict: Optional[dict[str, str]] = None, - seed: Optional[int] = 0, + env_dict: dict[str, str] | None = None, + seed: int | None = 0, auto_port: bool = True, - max_wait_seconds: Optional[float] = None, + max_wait_seconds: float | None = None, ) -> None: """Store custom child process function then invoke superclass constructor which will indirectly launch it.""" @@ -281,7 +284,7 @@ def __init__( max_wait_seconds=max_wait_seconds, ) - def _poll(self) -> Optional[int]: + def _poll(self) -> int | None: return self.proc.exitcode def __exit__(self, exc_type, exc_value, traceback): @@ -292,6 +295,131 @@ def __exit__(self, exc_type, exc_value, traceback): self.proc.kill() +class RemoteAnthropicServer: + DUMMY_API_KEY = "token-abc123" # vLLM's Anthropic server does not need API key + + def __init__( + self, + model: str, + vllm_serve_args: list[str], + *, + env_dict: dict[str, str] | None = None, + seed: int | None = 0, + auto_port: bool = True, + max_wait_seconds: float | None = None, + ) -> None: + if auto_port: + if "-p" in vllm_serve_args or "--port" in vllm_serve_args: + raise ValueError( + "You have manually specified the port when `auto_port=True`." + ) + + # Don't mutate the input args + vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())] + if seed is not None: + if "--seed" in vllm_serve_args: + raise ValueError( + f"You have manually specified the seed when `seed={seed}`." + ) + + vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] + + parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.") + subparsers = parser.add_subparsers(required=False, dest="subparser") + parser = ServeSubcommand().subparser_init(subparsers) + args = parser.parse_args(["--model", model, *vllm_serve_args]) + self.host = str(args.host or "localhost") + self.port = int(args.port) + + self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None + + # download the model before starting the server to avoid timeout + is_local = os.path.isdir(model) + if not is_local: + engine_args = AsyncEngineArgs.from_cli_args(args) + model_config = engine_args.create_model_config() + load_config = engine_args.create_load_config() + + model_loader = get_model_loader(load_config) + model_loader.download_model(model_config) + + env = os.environ.copy() + # the current process might initialize cuda, + # to be safe, we should use spawn method + env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + if env_dict is not None: + env.update(env_dict) + self.proc = subprocess.Popen( + [ + sys.executable, + "-m", + "vllm.entrypoints.anthropic.api_server", + model, + *vllm_serve_args, + ], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + max_wait_seconds = max_wait_seconds or 240 + self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + try: + self.proc.wait(8) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() + + def _wait_for_server(self, *, url: str, timeout: float): + # run health check + start = time.time() + while True: + try: + if requests.get(url).status_code == 200: + break + except Exception: + # this exception can only be raised by requests.get, + # which means the server is not ready yet. 
+ # the stack trace is not useful, so we suppress it + # by using `raise from None`. + result = self.proc.poll() + if result is not None and result != 0: + raise RuntimeError("Server exited unexpectedly.") from None + + time.sleep(0.5) + if time.time() - start > timeout: + raise RuntimeError("Server failed to start in time.") from None + + @property + def url_root(self) -> str: + return f"http://{self.host}:{self.port}" + + def url_for(self, *parts: str) -> str: + return self.url_root + "/" + "/".join(parts) + + def get_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return anthropic.Anthropic( + base_url=self.url_for(), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) + + def get_async_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return anthropic.AsyncAnthropic( + base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs + ) + + def _test_completion( client: openai.OpenAI, model: str, @@ -547,11 +675,11 @@ def compare_two_settings( model: str, arg1: list[str], arg2: list[str], - env1: Optional[dict[str, str]] = None, - env2: Optional[dict[str, str]] = None, + env1: dict[str, str] | None = None, + env2: dict[str, str] | None = None, *, method: str = "generate", - max_wait_seconds: Optional[float] = None, + max_wait_seconds: float | None = None, ) -> None: """ Launch API server with two different sets of arguments/environments @@ -577,10 +705,10 @@ def compare_two_settings( def compare_all_settings( model: str, all_args: list[list[str]], - all_envs: list[Optional[dict[str, str]]], + all_envs: list[dict[str, str] | None], *, method: str = "generate", - max_wait_seconds: Optional[float] = None, + max_wait_seconds: float | None = None, ) -> None: """ Launch API server with several different sets of arguments/environments @@ -785,8 +913,8 @@ def get_physical_device_indices(devices): def wait_for_gpu_memory_to_clear( *, devices: list[int], - threshold_bytes: Optional[int] = None, - threshold_ratio: Optional[float] = None, + threshold_bytes: int | None = None, + threshold_ratio: float | None = None, timeout_s: float = 120, ) -> None: assert threshold_bytes is not None or threshold_ratio is not None @@ -983,6 +1111,11 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # `cloudpickle` allows pickling complex functions directly input_bytes = cloudpickle.dumps((f, output_filepath)) + repo_root = str(VLLM_PATH.resolve()) + + env = dict(env or os.environ) + env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "") + cmd = [sys.executable, "-m", f"{module_name}"] returned = subprocess.run( @@ -1002,7 +1135,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: def create_new_process_for_each_test( - method: Optional[Literal["spawn", "fork"]] = None, + method: Literal["spawn", "fork"] | None = None, ) -> Callable[[Callable[_P, None]], Callable[_P, None]]: """Creates a decorator that runs each test function in a new process. @@ -1098,9 +1231,9 @@ async def completions_with_server_args( prompts: list[str], model_name: str, server_cli_args: list[str], - num_logprobs: Optional[int], + num_logprobs: int | None, max_wait_seconds: int = 240, - max_tokens: Union[int, list] = 5, + max_tokens: int | list = 5, ) -> list[Completion]: """Construct a remote OpenAI server, obtain an async client to the server & invoke the completions API to obtain completions. 
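The `RemoteAnthropicServer` helper added above follows the same lifecycle as `RemoteOpenAIServer`: construct it with a model and `vllm serve` arguments, wait for `/health`, then hand out Anthropic SDK clients. A hedged usage sketch, with a placeholder model id and serve args (not values taken from this patch):

def _sketch_remote_anthropic_server() -> None:
    # Illustrative only: "example-org/example-model" and the serve args are
    # placeholders; a real test would substitute an actual model.
    with RemoteAnthropicServer(
        "example-org/example-model",
        ["--enforce-eager", "--max-model-len", "1024"],
    ) as server:
        client = server.get_client()  # anthropic.Anthropic bound to the server
        assert server.url_for("health") == server.url_root + "/health"
        _ = client  # a real test would issue requests against the server here
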
@@ -1260,3 +1393,23 @@ def check_answers( frac_ok = numok / len(answer) print(f"Num OK: {numok}/{len(answer)} {frac_ok}") assert frac_ok >= accept_rate + + +def flat_product(*iterables: Iterable[Any]): + """ + Flatten lists of tuples of the cartesian product. + Useful when we want to avoid nested tuples to allow + test params to be unpacked directly from the decorator. + + Example: + flat_product([(1, 2), (3, 4)], ["a", "b"]) -> + [ + (1, 2, "a"), + (1, 2, "b"), + (3, 4, "a"), + (3, 4, "b"), + ] + """ + for element in itertools.product(*iterables): + normalized = (e if isinstance(e, tuple) else (e,) for e in element) + yield tuple(itertools.chain(*normalized)) diff --git a/tests/utils_/test_async_utils.py b/tests/utils_/test_async_utils.py new file mode 100644 index 000000000000..03d116bdfd81 --- /dev/null +++ b/tests/utils_/test_async_utils.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +from collections.abc import AsyncIterator + +import pytest + +from vllm.utils.async_utils import merge_async_iterators + + +async def _mock_async_iterator(idx: int): + try: + while True: + yield f"item from iterator {idx}" + await asyncio.sleep(0.1) + except asyncio.CancelledError: + print(f"iterator {idx} cancelled") + + +@pytest.mark.asyncio +async def test_merge_async_iterators(): + iterators = [_mock_async_iterator(i) for i in range(3)] + merged_iterator = merge_async_iterators(*iterators) + + async def stream_output(generator: AsyncIterator[tuple[int, str]]): + async for idx, output in generator: + print(f"idx: {idx}, output: {output}") + + task = asyncio.create_task(stream_output(merged_iterator)) + await asyncio.sleep(0.5) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for iterator in iterators: + try: + await asyncio.wait_for(anext(iterator), 1) + except StopAsyncIteration: + # All iterators should be cancelled and print this message. 
+ print("Iterator was cancelled normally") + except (Exception, asyncio.CancelledError) as e: + raise AssertionError() from e diff --git a/tests/utils_/test_collection_utils.py b/tests/utils_/test_collection_utils.py new file mode 100644 index 000000000000..19f4a3d1c95f --- /dev/null +++ b/tests/utils_/test_collection_utils.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.utils.collection_utils import swap_dict_values + + +@pytest.mark.parametrize( + "obj,key1,key2", + [ + # Tests for both keys exist + ({1: "a", 2: "b"}, 1, 2), + # Tests for one key does not exist + ({1: "a", 2: "b"}, 1, 3), + # Tests for both keys do not exist + ({1: "a", 2: "b"}, 3, 4), + ], +) +def test_swap_dict_values(obj, key1, key2): + original_obj = obj.copy() + + swap_dict_values(obj, key1, key2) + + if key1 in original_obj: + assert obj[key2] == original_obj[key1] + else: + assert key2 not in obj + if key2 in original_obj: + assert obj[key1] == original_obj[key2] + else: + assert key1 not in obj diff --git a/tests/utils_/test_func_utils.py b/tests/utils_/test_func_utils.py new file mode 100644 index 000000000000..9ce1ada095f1 --- /dev/null +++ b/tests/utils_/test_func_utils.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa + +import pytest + +from vllm.utils.func_utils import deprecate_kwargs, supports_kw + +from ..utils import error_on_warning + + +def test_deprecate_kwargs_always(): + @deprecate_kwargs("old_arg", is_deprecated=True) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(DeprecationWarning): + dummy(new_arg=1) + + +def test_deprecate_kwargs_never(): + @deprecate_kwargs("old_arg", is_deprecated=False) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with error_on_warning(DeprecationWarning): + dummy(old_arg=1) + + with error_on_warning(DeprecationWarning): + dummy(new_arg=1) + + +def test_deprecate_kwargs_dynamic(): + is_deprecated = True + + @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(DeprecationWarning): + dummy(new_arg=1) + + is_deprecated = False + + with error_on_warning(DeprecationWarning): + dummy(old_arg=1) + + with error_on_warning(DeprecationWarning): + dummy(new_arg=1) + + +def test_deprecate_kwargs_additional_message(): + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="abcd"): + dummy(old_arg=1) + + +@pytest.mark.parametrize( + ("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"), + [ + # Tests for positional argument support + (lambda foo: None, "foo", True, True, False), + (lambda foo: None, "foo", False, True, True), + # Tests for positional or keyword / keyword only + (lambda foo=100: None, "foo", True, True, False), + (lambda *, foo: None, "foo", False, True, True), + # Tests to make sure the names of variadic params are NOT supported + (lambda *args: None, "args", False, True, False), + (lambda **kwargs: None, "kwargs", False, True, False), + # Tests for if we allow 
var kwargs to add support + (lambda foo: None, "something_else", False, True, False), + (lambda foo, **kwargs: None, "something_else", False, True, True), + (lambda foo, **kwargs: None, "kwargs", True, True, False), + (lambda foo, **kwargs: None, "foo", True, True, False), + ], +) +def test_supports_kw( + callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported +): + assert ( + supports_kw( + callable=callable, + kw_name=kw_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) + == is_supported + ) diff --git a/tests/utils_/test_hashing.py b/tests/utils_/test_hashing.py new file mode 100644 index 000000000000..484627a547d0 --- /dev/null +++ b/tests/utils_/test_hashing.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import hashlib +import pickle + +import pytest + +from vllm.utils.hashing import sha256 + + +@pytest.mark.parametrize("input", [(), ("abc",), (None,), (None, bool, [1, 2, 3])]) +def test_sha256(input: tuple): + digest = sha256(input) + assert digest is not None + assert isinstance(digest, bytes) + assert digest != b"" + + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + assert digest == hashlib.sha256(input_bytes).digest() + + # hashing again, returns the same value + assert digest == sha256(input) + + # hashing different input, returns different value + assert digest != sha256(input + (1,)) diff --git a/tests/utils_/test_import_utils.py b/tests/utils_/test_import_utils.py new file mode 100644 index 000000000000..d42685b3fc9a --- /dev/null +++ b/tests/utils_/test_import_utils.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.utils.import_utils import PlaceholderModule + + +def _raises_module_not_found(): + return pytest.raises(ModuleNotFoundError, match="No module named") + + +def test_placeholder_module_error_handling(): + placeholder = PlaceholderModule("placeholder_1234") + + with _raises_module_not_found(): + int(placeholder) + + with _raises_module_not_found(): + placeholder() + + with _raises_module_not_found(): + _ = placeholder.some_attr + + with _raises_module_not_found(): + # Test conflict with internal __name attribute + _ = placeholder.name + + # OK to print the placeholder or use it in a f-string + _ = repr(placeholder) + _ = str(placeholder) + + # No error yet; only error when it is used downstream + placeholder_attr = placeholder.placeholder_attr("attr") + + with _raises_module_not_found(): + int(placeholder_attr) + + with _raises_module_not_found(): + placeholder_attr() + + with _raises_module_not_found(): + _ = placeholder_attr.some_attr + + with _raises_module_not_found(): + # Test conflict with internal __module attribute + _ = placeholder_attr.module diff --git a/tests/utils_/test_jsontree.py b/tests/utils_/test_jsontree.py new file mode 100644 index 000000000000..0af2751b2638 --- /dev/null +++ b/tests/utils_/test_jsontree.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.utils.jsontree import json_count_leaves + + +def test_json_count_leaves(): + """Test json_count_leaves function from jsontree utility.""" + + # Single leaf values + assert json_count_leaves(42) == 1 + assert json_count_leaves("hello") == 1 + assert json_count_leaves(None) == 1 + + # Empty containers + assert json_count_leaves([]) == 0 + assert 
json_count_leaves({}) == 0 + assert json_count_leaves(()) == 0 + + # Flat structures + assert json_count_leaves([1, 2, 3]) == 3 + assert json_count_leaves({"a": 1, "b": 2}) == 2 + assert json_count_leaves((1, 2, 3)) == 3 + + # Nested structures + nested_dict = {"a": 1, "b": {"c": 2, "d": 3}} + assert json_count_leaves(nested_dict) == 3 + + nested_list = [1, [2, 3], 4] + assert json_count_leaves(nested_list) == 4 + + mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4} + assert json_count_leaves(mixed_nested) == 4 diff --git a/tests/utils_/test_mem_utils.py b/tests/utils_/test_mem_utils.py new file mode 100644 index 000000000000..4b1058be412d --- /dev/null +++ b/tests/utils_/test_mem_utils.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from vllm_test_utils.monitor import monitor + +from vllm.utils.mem_utils import MemorySnapshot, memory_profiling + +from ..utils import create_new_process_for_each_test + + +@create_new_process_for_each_test() +def test_memory_profiling(): + # Fake out some model loading + inference memory usage to test profiling + # Memory used by other processes will show up as cuda usage outside of torch + from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + + lib = CudaRTLibrary() + # 512 MiB allocation outside of this instance + handle1 = lib.cudaMalloc(512 * 1024 * 1024) + + baseline_snapshot = MemorySnapshot() + + # load weights + + weights = torch.randn(128, 1024, 1024, device="cuda", dtype=torch.float32) + + weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB + + def measure_current_non_torch(): + free, total = torch.cuda.mem_get_info() + current_used = total - free + current_torch = torch.cuda.memory_reserved() + current_non_torch = current_used - current_torch + return current_non_torch + + with ( + memory_profiling( + baseline_snapshot=baseline_snapshot, weights_memory=weights_memory + ) as result, + monitor(measure_current_non_torch) as monitored_values, + ): + # make a memory spike, 1 GiB + spike = torch.randn(256, 1024, 1024, device="cuda", dtype=torch.float32) + del spike + + # Add some extra non-torch memory 256 MiB (simulate NCCL) + handle2 = lib.cudaMalloc(256 * 1024 * 1024) + + # this is an analytic value, it is exact, + # we only have 256 MiB non-torch memory increase + measured_diff = monitored_values.values[-1] - monitored_values.values[0] + assert measured_diff == 256 * 1024 * 1024 + + # Check that the memory usage is within 5% of the expected values + # 5% tolerance is caused by cuda runtime. 
+ # we cannot control cuda runtime in the granularity of bytes, + # which causes a small error (<10 MiB in practice) + non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa + assert abs(non_torch_ratio - 1) <= 0.05 + assert result.torch_peak_increase == 1024 * 1024 * 1024 + del weights + lib.cudaFree(handle1) + lib.cudaFree(handle2) diff --git a/tests/utils_/test_network_utils.py b/tests/utils_/test_network_utils.py new file mode 100644 index 000000000000..bc274f0679b8 --- /dev/null +++ b/tests/utils_/test_network_utils.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import socket + +import pytest +import zmq + +from vllm.utils.network_utils import ( + get_open_port, + get_tcp_uri, + join_host_port, + make_zmq_path, + make_zmq_socket, + split_host_port, + split_zmq_path, +) + + +def test_get_open_port(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_PORT", "5678") + # make sure we can get multiple ports, even if the env var is set + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: + s1.bind(("localhost", get_open_port())) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2: + s2.bind(("localhost", get_open_port())) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: + s3.bind(("localhost", get_open_port())) + + +@pytest.mark.parametrize( + "path,expected", + [ + ("ipc://some_path", ("ipc", "some_path", "")), + ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), + ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address + ("inproc://some_identifier", ("inproc", "some_identifier", "")), + ], +) +def test_split_zmq_path(path, expected): + assert split_zmq_path(path) == expected + + +@pytest.mark.parametrize( + "invalid_path", + [ + "invalid_path", # Missing scheme + "tcp://127.0.0.1", # Missing port + "tcp://[::1]", # Missing port for IPv6 + "tcp://:5555", # Missing host + ], +) +def test_split_zmq_path_invalid(invalid_path): + with pytest.raises(ValueError): + split_zmq_path(invalid_path) + + +def test_make_zmq_socket_ipv6(): + # Check if IPv6 is supported by trying to create an IPv6 socket + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.close() + except OSError: + pytest.skip("IPv6 is not supported on this system") + + ctx = zmq.Context() + ipv6_path = "tcp://[::]:5555" # IPv6 loopback address + socket_type = zmq.REP # Example socket type + + # Create the socket + zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) + + # Verify that the IPV6 option is set + assert zsock.getsockopt(zmq.IPV6) == 1, ( + "IPV6 option should be enabled for IPv6 addresses" + ) + + # Clean up + zsock.close() + ctx.term() + + +def test_make_zmq_path(): + assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555" + assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555" + + +def test_get_tcp_uri(): + assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555" + assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555" + + +def test_split_host_port(): + # valid ipv4 + assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555) + # invalid ipv4 + with pytest.raises(ValueError): + # multi colon + assert split_host_port("127.0.0.1::5555") + with pytest.raises(ValueError): + # tailing colon + assert split_host_port("127.0.0.1:5555:") + with pytest.raises(ValueError): + # no colon + assert split_host_port("127.0.0.15555") + with pytest.raises(ValueError): + # none int port + 
assert split_host_port("127.0.0.1:5555a") + + # valid ipv6 + assert split_host_port("[::1]:5555") == ("::1", 5555) + # invalid ipv6 + with pytest.raises(ValueError): + # multi colon + assert split_host_port("[::1]::5555") + with pytest.raises(IndexError): + # no colon + assert split_host_port("[::1]5555") + with pytest.raises(ValueError): + # none int port + assert split_host_port("[::1]:5555a") + + +def test_join_host_port(): + assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555" + assert join_host_port("::1", 5555) == "[::1]:5555" diff --git a/tests/utils_/test_serial_utils.py b/tests/utils_/test_serial_utils.py new file mode 100644 index 000000000000..7f2c1bdacf90 --- /dev/null +++ b/tests/utils_/test_serial_utils.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from tests.models.utils import check_embeddings_close +from vllm.utils.serial_utils import ( + EMBED_DTYPE_TO_TORCH_DTYPE, + ENDIANNESS, + binary2tensor, + tensor2binary, +) + + +@pytest.mark.parametrize("endianness", ENDIANNESS) +@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys()) +@torch.inference_mode +def test_encode_and_decode(embed_dtype: str, endianness: str): + for i in range(10): + tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32) + shape = tensor.shape + binary = tensor2binary(tensor, embed_dtype, endianness) + new_tensor = binary2tensor(binary, shape, embed_dtype, endianness).to( + torch.float32 + ) + + if embed_dtype in ["float32", "float16"]: + torch.testing.assert_close(tensor, new_tensor, atol=0.001, rtol=0.001) + elif embed_dtype == "bfloat16": + torch.testing.assert_close(tensor, new_tensor, atol=0.01, rtol=0.01) + else: # for fp8 + torch.testing.assert_close(tensor, new_tensor, atol=0.1, rtol=0.1) + + check_embeddings_close( + embeddings_0_lst=tensor.view(1, -1), + embeddings_1_lst=new_tensor.view(1, -1), + name_0="gt", + name_1="new", + tol=1e-2, + ) diff --git a/tests/utils_/test_system_utils.py b/tests/utils_/test_system_utils.py new file mode 100644 index 000000000000..3d1b1fc4ce37 --- /dev/null +++ b/tests/utils_/test_system_utils.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import tempfile +from pathlib import Path + +from vllm.utils.system_utils import unique_filepath + + +def test_unique_filepath(): + temp_dir = tempfile.mkdtemp() + path_fn = lambda i: Path(temp_dir) / f"file_{i}.txt" + paths = set() + for i in range(10): + path = unique_filepath(path_fn) + path.write_text("test") + paths.add(path) + assert len(paths) == 10 + assert len(list(Path(temp_dir).glob("*.txt"))) == 10 diff --git a/tests/utils_/test_torch_utils.py b/tests/utils_/test_torch_utils.py new file mode 100644 index 000000000000..0a30b9727f4d --- /dev/null +++ b/tests/utils_/test_torch_utils.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm.utils.torch_utils import ( + common_broadcastable_dtype, + current_stream, + is_lossless_cast, +) + + +@pytest.mark.parametrize( + ("src_dtype", "tgt_dtype", "expected_result"), + [ + # Different precision_levels + (torch.bool, torch.int8, True), + (torch.bool, torch.float16, True), + (torch.bool, torch.complex32, True), + (torch.int64, torch.bool, False), + (torch.int64, torch.float16, True), + (torch.int64, torch.complex32, 
True), + (torch.float64, torch.bool, False), + (torch.float64, torch.int8, False), + (torch.float64, torch.complex32, True), + (torch.complex128, torch.bool, False), + (torch.complex128, torch.int8, False), + (torch.complex128, torch.float16, False), + # precision_level=0 + (torch.bool, torch.bool, True), + # precision_level=1 + (torch.int8, torch.int16, True), + (torch.int16, torch.int8, False), + (torch.uint8, torch.int8, False), + (torch.int8, torch.uint8, False), + # precision_level=2 + (torch.float16, torch.float32, True), + (torch.float32, torch.float16, False), + (torch.bfloat16, torch.float32, True), + (torch.float32, torch.bfloat16, False), + # precision_level=3 + (torch.complex32, torch.complex64, True), + (torch.complex64, torch.complex32, False), + ], +) +def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): + assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result + + +@pytest.mark.parametrize( + ("dtypes", "expected_result"), + [ + ([torch.bool], torch.bool), + ([torch.bool, torch.int8], torch.int8), + ([torch.bool, torch.int8, torch.float16], torch.float16), + ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501 + ], +) +def test_common_broadcastable_dtype(dtypes, expected_result): + assert common_broadcastable_dtype(dtypes) == expected_result + + +def _test_stream_thread(main_expected_stream: torch.cuda.Stream): + import threading + + child_stream = torch.cuda.Stream() + thread_stream_ready = threading.Event() + thread_can_exit = threading.Event() + + def child_thread_func(): + with torch.cuda.stream(child_stream): + thread_stream_ready.set() + thread_can_exit.wait(timeout=10) + + child_thread = threading.Thread(target=child_thread_func) + child_thread.start() + + try: + assert thread_stream_ready.wait(timeout=5), ( + "Child thread failed to enter stream context in time" + ) + + main_current_stream = current_stream() + + assert main_current_stream != child_stream, ( + "Main thread's current_stream was contaminated by child thread" + ) + assert main_current_stream == main_expected_stream, ( + f"Main thread's stream changed unexpectedly. 
" + f"Expected {main_expected_stream}, got {main_current_stream}" + ) + + thread_can_exit.set() + + finally: + child_thread.join(timeout=5) + if child_thread.is_alive(): + pytest.fail("Child thread failed to exit properly") + + +def test_current_stream_multithread(): + from vllm.platforms import current_platform + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + if current_platform.is_rocm(): + main_dedicated_stream = current_stream() + + assert main_dedicated_stream.cuda_stream != 0, ( + "ROCm should create a dedicated stream, not use default stream (0x0)" + ) + + main_stream_again = current_stream() + assert main_stream_again == main_dedicated_stream, ( + "Multiple calls to current_stream should return the same dedicated stream" + ) + + _test_stream_thread(main_dedicated_stream) + else: + main_default_stream = torch.cuda.default_stream() + main_initial_stream = current_stream() + + assert main_initial_stream == main_default_stream, ( + "First call to current_stream should return default stream on CUDA" + ) + + _test_stream_thread(main_default_stream) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index cd5fa550498b..08dc7632b74b 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -2,153 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # ruff: noqa -import asyncio -import hashlib import json import os -import pickle -import socket import tempfile -from collections.abc import AsyncIterator from pathlib import Path from unittest.mock import patch import pytest import torch import yaml -import zmq from transformers import AutoTokenizer -from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens from vllm.utils import ( FlexibleArgumentParser, - MemorySnapshot, - PlaceholderModule, bind_kv_cache, - common_broadcastable_dtype, - current_stream, - deprecate_kwargs, - get_open_port, - get_tcp_uri, - is_lossless_cast, - join_host_port, - make_zmq_path, - make_zmq_socket, - memory_profiling, - merge_async_iterators, - sha256, - split_host_port, - split_zmq_path, - supports_kw, - swap_dict_values, - unique_filepath, ) - -from ..utils import create_new_process_for_each_test, error_on_warning - - -@pytest.mark.asyncio -async def test_merge_async_iterators(): - async def mock_async_iterator(idx: int): - try: - while True: - yield f"item from iterator {idx}" - await asyncio.sleep(0.1) - except asyncio.CancelledError: - print(f"iterator {idx} cancelled") - - iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator = merge_async_iterators(*iterators) - - async def stream_output(generator: AsyncIterator[tuple[int, str]]): - async for idx, output in generator: - print(f"idx: {idx}, output: {output}") - - task = asyncio.create_task(stream_output(merged_iterator)) - await asyncio.sleep(0.5) - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - for iterator in iterators: - try: - # Can use anext() in python >= 3.10 - await asyncio.wait_for(iterator.__anext__(), 1) - except StopAsyncIteration: - # All iterators should be cancelled and print this message. 
- print("Iterator was cancelled normally") - except (Exception, asyncio.CancelledError) as e: - raise AssertionError() from e - - -def test_deprecate_kwargs_always(): - @deprecate_kwargs("old_arg", is_deprecated=True) - def dummy(*, old_arg: object = None, new_arg: object = None): - pass - - with pytest.warns(DeprecationWarning, match="'old_arg'"): - dummy(old_arg=1) - - with error_on_warning(DeprecationWarning): - dummy(new_arg=1) - - -def test_deprecate_kwargs_never(): - @deprecate_kwargs("old_arg", is_deprecated=False) - def dummy(*, old_arg: object = None, new_arg: object = None): - pass - - with error_on_warning(DeprecationWarning): - dummy(old_arg=1) - - with error_on_warning(DeprecationWarning): - dummy(new_arg=1) - - -def test_deprecate_kwargs_dynamic(): - is_deprecated = True - - @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) - def dummy(*, old_arg: object = None, new_arg: object = None): - pass - - with pytest.warns(DeprecationWarning, match="'old_arg'"): - dummy(old_arg=1) - - with error_on_warning(DeprecationWarning): - dummy(new_arg=1) - - is_deprecated = False - - with error_on_warning(DeprecationWarning): - dummy(old_arg=1) - - with error_on_warning(DeprecationWarning): - dummy(new_arg=1) - - -def test_deprecate_kwargs_additional_message(): - @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") - def dummy(*, old_arg: object = None, new_arg: object = None): - pass - - with pytest.warns(DeprecationWarning, match="abcd"): - dummy(old_arg=1) - - -def test_get_open_port(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv("VLLM_PORT", "5678") - # make sure we can get multiple ports, even if the env var is set - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: - s1.bind(("localhost", get_open_port())) - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2: - s2.bind(("localhost", get_open_port())) - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: - s3.bind(("localhost", get_open_port())) +from ..utils import create_new_process_for_each_test, flat_product # Tests for FlexibleArgumentParser @@ -300,7 +172,7 @@ def test_dict_args(parser): "val2", "--hf-overrides.key2.key4", "val3", - # Test compile config and compilation level + # Test compile config and compilation mode "-O.use_inductor=true", "-O.backend", "custom", @@ -353,7 +225,7 @@ def test_dict_args(parser): }, } assert parsed_args.compilation_config == { - "level": 1, + "mode": 1, "use_inductor": True, "backend": "custom", "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"], @@ -368,7 +240,7 @@ def test_duplicate_dict_args(caplog_vllm, parser): "--hf-overrides.key1", "val2", "-O1", - "-O.level", + "-O.mode", "2", "-O3", ] @@ -376,100 +248,12 @@ def test_duplicate_dict_args(caplog_vllm, parser): parsed_args = parser.parse_args(args) # Should be the last value assert parsed_args.hf_overrides == {"key1": "val2"} - assert parsed_args.compilation_config == {"level": 3} + assert parsed_args.compilation_config == {"mode": 3} assert len(caplog_vllm.records) == 1 assert "duplicate" in caplog_vllm.text assert "--hf-overrides.key1" in caplog_vllm.text - assert "-O.level" in caplog_vllm.text - - -@pytest.mark.parametrize( - "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", - [ - # Tests for positional argument support - (lambda foo: None, "foo", True, True, False), - (lambda foo: None, "foo", False, True, True), - # Tests for positional or keyword / keyword only - (lambda foo=100: None, "foo", True, True, False), - 
(lambda *, foo: None, "foo", False, True, True), - # Tests to make sure the names of variadic params are NOT supported - (lambda *args: None, "args", False, True, False), - (lambda **kwargs: None, "kwargs", False, True, False), - # Tests for if we allow var kwargs to add support - (lambda foo: None, "something_else", False, True, False), - (lambda foo, **kwargs: None, "something_else", False, True, True), - (lambda foo, **kwargs: None, "kwargs", True, True, False), - (lambda foo, **kwargs: None, "foo", True, True, False), - ], -) -def test_supports_kw( - callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported -): - assert ( - supports_kw( - callable=callable, - kw_name=kw_name, - requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs, - ) - == is_supported - ) - - -@create_new_process_for_each_test() -def test_memory_profiling(): - # Fake out some model loading + inference memory usage to test profiling - # Memory used by other processes will show up as cuda usage outside of torch - from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary - - lib = CudaRTLibrary() - # 512 MiB allocation outside of this instance - handle1 = lib.cudaMalloc(512 * 1024 * 1024) - - baseline_snapshot = MemorySnapshot() - - # load weights - - weights = torch.randn(128, 1024, 1024, device="cuda", dtype=torch.float32) - - weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB - - def measure_current_non_torch(): - free, total = torch.cuda.mem_get_info() - current_used = total - free - current_torch = torch.cuda.memory_reserved() - current_non_torch = current_used - current_torch - return current_non_torch - - with ( - memory_profiling( - baseline_snapshot=baseline_snapshot, weights_memory=weights_memory - ) as result, - monitor(measure_current_non_torch) as monitored_values, - ): - # make a memory spike, 1 GiB - spike = torch.randn(256, 1024, 1024, device="cuda", dtype=torch.float32) - del spike - - # Add some extra non-torch memory 256 MiB (simulate NCCL) - handle2 = lib.cudaMalloc(256 * 1024 * 1024) - - # this is an analytic value, it is exact, - # we only have 256 MiB non-torch memory increase - measured_diff = monitored_values.values[-1] - monitored_values.values[0] - assert measured_diff == 256 * 1024 * 1024 - - # Check that the memory usage is within 5% of the expected values - # 5% tolerance is caused by cuda runtime. 
- # we cannot control cuda runtime in the granularity of bytes, - # which causes a small error (<10 MiB in practice) - non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa - assert abs(non_torch_ratio - 1) <= 0.05 - assert result.torch_peak_increase == 1024 * 1024 * 1024 - del weights - lib.cudaFree(handle1) - lib.cudaFree(handle2) + assert "-O.mode" in caplog_vllm.text def test_bind_kv_cache(): @@ -538,7 +322,7 @@ def test_bind_kv_cache_non_attention(): def test_bind_kv_cache_pp(): - with patch("vllm.utils.cuda_device_count_stateless", lambda: 2): + with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2): # this test runs with 1 GPU, but we simulate 2 GPUs cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2)) with set_current_vllm_config(cfg): @@ -553,120 +337,6 @@ def test_bind_kv_cache_pp(): assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0] -@pytest.mark.parametrize( - ("src_dtype", "tgt_dtype", "expected_result"), - [ - # Different precision_levels - (torch.bool, torch.int8, True), - (torch.bool, torch.float16, True), - (torch.bool, torch.complex32, True), - (torch.int64, torch.bool, False), - (torch.int64, torch.float16, True), - (torch.int64, torch.complex32, True), - (torch.float64, torch.bool, False), - (torch.float64, torch.int8, False), - (torch.float64, torch.complex32, True), - (torch.complex128, torch.bool, False), - (torch.complex128, torch.int8, False), - (torch.complex128, torch.float16, False), - # precision_level=0 - (torch.bool, torch.bool, True), - # precision_level=1 - (torch.int8, torch.int16, True), - (torch.int16, torch.int8, False), - (torch.uint8, torch.int8, False), - (torch.int8, torch.uint8, False), - # precision_level=2 - (torch.float16, torch.float32, True), - (torch.float32, torch.float16, False), - (torch.bfloat16, torch.float32, True), - (torch.float32, torch.bfloat16, False), - # precision_level=3 - (torch.complex32, torch.complex64, True), - (torch.complex64, torch.complex32, False), - ], -) -def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): - assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result - - -@pytest.mark.parametrize( - ("dtypes", "expected_result"), - [ - ([torch.bool], torch.bool), - ([torch.bool, torch.int8], torch.int8), - ([torch.bool, torch.int8, torch.float16], torch.float16), - ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501 - ], -) -def test_common_broadcastable_dtype(dtypes, expected_result): - assert common_broadcastable_dtype(dtypes) == expected_result - - -def test_placeholder_module_error_handling(): - placeholder = PlaceholderModule("placeholder_1234") - - def build_ctx(): - return pytest.raises(ModuleNotFoundError, match="No module named") - - with build_ctx(): - int(placeholder) - - with build_ctx(): - placeholder() - - with build_ctx(): - _ = placeholder.some_attr - - with build_ctx(): - # Test conflict with internal __name attribute - _ = placeholder.name - - # OK to print the placeholder or use it in a f-string - _ = repr(placeholder) - _ = str(placeholder) - - # No error yet; only error when it is used downstream - placeholder_attr = placeholder.placeholder_attr("attr") - - with build_ctx(): - int(placeholder_attr) - - with build_ctx(): - placeholder_attr() - - with build_ctx(): - _ = placeholder_attr.some_attr - - with build_ctx(): - # Test conflict with internal __module attribute - _ = placeholder_attr.module - - -@pytest.mark.parametrize( - "obj,key1,key2", - [ - # Tests for both 
keys exist - ({1: "a", 2: "b"}, 1, 2), - # Tests for one key does not exist - ({1: "a", 2: "b"}, 1, 3), - # Tests for both keys do not exist - ({1: "a", 2: "b"}, 3, 4), - ], -) -def test_swap_dict_values(obj, key1, key2): - original_obj = obj.copy() - swap_dict_values(obj, key1, key2) - if key1 in original_obj: - assert obj[key2] == original_obj[key1] - else: - assert key2 not in obj - if key2 in original_obj: - assert obj[key1] == original_obj[key2] - else: - assert key1 not in obj - - def test_model_specification( parser_with_config, cli_config_file, cli_config_file_with_model ): @@ -749,151 +419,6 @@ def test_model_specification( assert args.port == 12312 -@pytest.mark.parametrize("input", [(), ("abc",), (None,), (None, bool, [1, 2, 3])]) -def test_sha256(input: tuple): - digest = sha256(input) - assert digest is not None - assert isinstance(digest, bytes) - assert digest != b"" - - input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - assert digest == hashlib.sha256(input_bytes).digest() - - # hashing again, returns the same value - assert digest == sha256(input) - - # hashing different input, returns different value - assert digest != sha256(input + (1,)) - - -@pytest.mark.parametrize( - "path,expected", - [ - ("ipc://some_path", ("ipc", "some_path", "")), - ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), - ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address - ("inproc://some_identifier", ("inproc", "some_identifier", "")), - ], -) -def test_split_zmq_path(path, expected): - assert split_zmq_path(path) == expected - - -@pytest.mark.parametrize( - "invalid_path", - [ - "invalid_path", # Missing scheme - "tcp://127.0.0.1", # Missing port - "tcp://[::1]", # Missing port for IPv6 - "tcp://:5555", # Missing host - ], -) -def test_split_zmq_path_invalid(invalid_path): - with pytest.raises(ValueError): - split_zmq_path(invalid_path) - - -def test_make_zmq_socket_ipv6(): - # Check if IPv6 is supported by trying to create an IPv6 socket - try: - sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - sock.close() - except socket.error: - pytest.skip("IPv6 is not supported on this system") - - ctx = zmq.Context() - ipv6_path = "tcp://[::]:5555" # IPv6 loopback address - socket_type = zmq.REP # Example socket type - - # Create the socket - zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) - - # Verify that the IPV6 option is set - assert zsock.getsockopt(zmq.IPV6) == 1, ( - "IPV6 option should be enabled for IPv6 addresses" - ) - - # Clean up - zsock.close() - ctx.term() - - -def test_make_zmq_path(): - assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555" - assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555" - - -def test_get_tcp_uri(): - assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555" - assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555" - - -def test_split_host_port(): - # valid ipv4 - assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555) - # invalid ipv4 - with pytest.raises(ValueError): - # multi colon - assert split_host_port("127.0.0.1::5555") - with pytest.raises(ValueError): - # tailing colon - assert split_host_port("127.0.0.1:5555:") - with pytest.raises(ValueError): - # no colon - assert split_host_port("127.0.0.15555") - with pytest.raises(ValueError): - # none int port - assert split_host_port("127.0.0.1:5555a") - - # valid ipv6 - assert split_host_port("[::1]:5555") == ("::1", 5555) - # invalid ipv6 - with pytest.raises(ValueError): - # multi colon - assert 
split_host_port("[::1]::5555") - with pytest.raises(IndexError): - # no colon - assert split_host_port("[::1]5555") - with pytest.raises(ValueError): - # none int port - assert split_host_port("[::1]:5555a") - - -def test_join_host_port(): - assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555" - assert join_host_port("::1", 5555) == "[::1]:5555" - - -def test_json_count_leaves(): - """Test json_count_leaves function from jsontree utility.""" - from vllm.utils.jsontree import json_count_leaves - - # Single leaf values - assert json_count_leaves(42) == 1 - assert json_count_leaves("hello") == 1 - assert json_count_leaves(None) == 1 - - # Empty containers - assert json_count_leaves([]) == 0 - assert json_count_leaves({}) == 0 - assert json_count_leaves(()) == 0 - - # Flat structures - assert json_count_leaves([1, 2, 3]) == 3 - assert json_count_leaves({"a": 1, "b": 2}) == 2 - assert json_count_leaves((1, 2, 3)) == 3 - - # Nested structures - nested_dict = {"a": 1, "b": {"c": 2, "d": 3}} - assert json_count_leaves(nested_dict) == 3 - - nested_list = [1, [2, 3], 4] - assert json_count_leaves(nested_list) == 4 - - mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4} - assert json_count_leaves(mixed_nested) == 4 - - def test_convert_ids_list_to_tokens(): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") token_ids = tokenizer.encode("Hello, world!") @@ -903,50 +428,6 @@ def test_convert_ids_list_to_tokens(): assert tokens == ["Hello", ",", " world", "!"] -def test_current_stream_multithread(): - import threading - - if not torch.cuda.is_available(): - pytest.skip("CUDA not available") - - main_default_stream = torch.cuda.current_stream() - child_stream = torch.cuda.Stream() - - thread_stream_ready = threading.Event() - thread_can_exit = threading.Event() - - def child_thread_func(): - with torch.cuda.stream(child_stream): - thread_stream_ready.set() - thread_can_exit.wait(timeout=10) - - child_thread = threading.Thread(target=child_thread_func) - child_thread.start() - - try: - assert thread_stream_ready.wait(timeout=5), ( - "Child thread failed to enter stream context in time" - ) - - main_current_stream = current_stream() - - assert main_current_stream != child_stream, ( - "Main thread's current_stream was contaminated by child thread" - ) - assert main_current_stream == main_default_stream, ( - "Main thread's current_stream is not the default stream" - ) - - # Notify child thread it can exit - thread_can_exit.set() - - finally: - # Ensure child thread exits properly - child_thread.join(timeout=5) - if child_thread.is_alive(): - pytest.fail("Child thread failed to exit properly") - - def test_load_config_file(tmp_path): # Define the configuration data config_data = { @@ -984,13 +465,23 @@ def test_load_config_file(tmp_path): os.remove(str(config_file_path)) -def test_unique_filepath(): - temp_dir = tempfile.mkdtemp() - path_fn = lambda i: Path(temp_dir) / f"file_{i}.txt" - paths = set() - for i in range(10): - path = unique_filepath(path_fn) - path.write_text("test") - paths.add(path) - assert len(paths) == 10 - assert len(list(Path(temp_dir).glob("*.txt"))) == 10 +def test_flat_product(): + # Check regular itertools.product behavior + result1 = list(flat_product([1, 2, 3], ["a", "b"])) + assert result1 == [ + (1, "a"), + (1, "b"), + (2, "a"), + (2, "b"), + (3, "a"), + (3, "b"), + ] + + # check that the tuples get flattened + result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)])) + assert result2 == [ + (1, 2, "a", 5, 6), + (1, 2, "b", 5, 6), 
+ (3, 4, "a", 5, 6), + (3, 4, "b", 5, 6), + ] diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 188482e071ee..12f7fc66d17b 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -3,7 +3,6 @@ """Tests for v1 attention backends without GPUModelRunner dependency.""" from functools import partial -from typing import Optional, Union import pytest import torch @@ -14,12 +13,13 @@ create_common_attn_metadata, create_standard_kv_cache_spec, create_vllm_config, - get_attention_backend, + try_get_attention_backend, ) from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, set_kv_cache_layout, @@ -202,7 +202,7 @@ def run_attention_backend( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - sliding_window: Optional[int] = None, + sliding_window: int | None = None, ) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" @@ -214,7 +214,7 @@ def run_attention_backend( actual_backend = _Backend.FLEX_ATTENTION use_direct_block_mask = False - builder_cls, impl_cls = get_attention_backend(actual_backend) + builder_cls, impl_cls = try_get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed if actual_backend == _Backend.FLASHINFER: @@ -289,7 +289,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): def _test_backend_correctness( batch_spec: BatchSpec, model: str, - backend_to_test: list[Union[_Backend, str]], + backend_to_test: list[_Backend | str], mask_mod, *, block_size: int = 16, @@ -424,13 +424,14 @@ def _test_backend_correctness( for backend_name in backend_to_test: # FlashAttentionm + FlexAttention: # [2, num_blocks, block_size, num_kv_heads, head_size] - # FlashInfer: + # FlashInfer + Triton: # [num_blocks, 2, block_size, num_kv_heads, head_size] # Select the appropriate KV cache format for each backend kv_cache_for_backend = kv_cache - if backend_name == _Backend.FLASHINFER: + if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN): kv_cache_for_backend = kv_cache.transpose(0, 1) + if backend_name == _Backend.FLASHINFER: # For FlashInfer default to HND layout and kv_cache_for_backend = ( kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index debaa6a5e009..81fd6433b0c8 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for v1 MLA backends without GPUModelRunner dependency.""" +"""Tests for v1 MLA backends without GPUModelRunner dependency. -from typing import Optional, Union +Known Issues: +- FLASH_ATTN_MLA backend occasionally produces NaN values in + test_backend_correctness[mixed_small] when run after + test_backend_correctness[small_prefill], but passes when run alone. 
+""" import pytest import torch @@ -12,11 +16,14 @@ create_common_attn_metadata, create_standard_kv_cache_spec, create_vllm_config, - get_attention_backend, + try_get_attention_backend, ) from vllm import _custom_ops as ops from vllm.attention.backends.registry import _Backend -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.attention.ops.flashmla import is_flashmla_dense_supported +from vllm.config.vllm import set_current_vllm_config +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -31,6 +38,10 @@ if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) +# Remove FLASHMLA from the list if not supported +if not is_flashmla_dense_supported()[0]: + BACKENDS_TO_TEST.remove(_Backend.FLASHMLA) + torch.manual_seed(42) @@ -68,6 +79,12 @@ def _convert_dtype_to_torch(dtype): "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), + "spec_decode_small": BatchSpec( + seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4] + ), + "spec_decode_medium": BatchSpec( + seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8] + ), } @@ -81,8 +98,8 @@ def create_and_prepopulate_kv_cache( num_blocks: int, common_attn_metadata: CommonAttentionMetadata, randomize_blocks: bool = True, - kv_cache_dtype: Optional[str] = None, - scale: Union[float, torch.Tensor] = 1.0, + kv_cache_dtype: str | None = None, + scale: float | torch.Tensor = 1.0, ) -> torch.Tensor: """Create and prepopulate an MLA KV cache with context data. 
@@ -239,63 +256,66 @@ def run_attention_backend( ) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" - builder_cls, impl_cls = get_attention_backend(backend) + builder_cls, impl_cls = try_get_attention_backend(backend) - # Build metadata - builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) - attn_metadata = builder.build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) + # Set the current vllm config so that get_current_vllm_config() works + # in the backend implementations + with set_current_vllm_config(vllm_config): + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) - # Instantiate MLA implementation - num_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config - ) - num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config - ) - head_size = vllm_config.model_config.get_head_size() - scale = 1.0 / (head_size**0.5) - impl = impl_cls( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - logits_soft_cap=None, - attn_type="decoder", - kv_sharing_target_layer_name=None, - q_lora_rank=None, - kv_lora_rank=kv_lora_rank, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, - v_head_dim=v_head_dim, - kv_b_proj=mock_kv_b_proj, - ) + # Instantiate MLA implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + ) - # Process weights to create W_UK_T and W_UV attributes needed by MLA - act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) - impl.process_weights_after_loading(act_dtype) + # Process weights to create W_UK_T and W_UV attributes needed by MLA + act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + impl.process_weights_after_loading(act_dtype) - # Create mock layer and output buffer - mock_layer = MockAttentionLayer(device) - num_tokens = query.shape[0] - output = torch.empty( - num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device - ) + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + num_tokens = query.shape[0] + output = torch.empty( + num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device + ) - # Run forward pass - # NOTE: The query, key, and value are already shaped correctly - # in the calling test function. 
- output = impl.forward( - mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output - ) + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. + output = impl.forward( + mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output + ) - return output + return output @pytest.mark.parametrize( @@ -311,6 +331,8 @@ def run_attention_backend( "large_prefill", "single_decode", "single_prefill", + "spec_decode_small", + "spec_decode_medium", ], ) @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) @@ -330,10 +352,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): simulated paged KV cache. 5. Comparing the vLLM backend's output to the ground-truth SDPA output. """ + from vllm.v1.attention.backends.mla.common import QueryLenSupport + batch_spec = BATCH_SPECS[batch_spec_name] + is_spec_decode_test = batch_spec_name.startswith("spec_decode") + spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA} + + block_size = 16 + required_blocks = sum( + (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens + ) + # Add 1 for null block at index 0, and some buffer + num_gpu_blocks = required_blocks + 1 + 100 + vllm_config = create_vllm_config( - model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048 + model_name=model, + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=num_gpu_blocks, + block_size=block_size, ) + + # For spec decode tests, add a speculative_config to set the reorder_batch_threshold + if is_spec_decode_test: + from vllm.config import SpeculativeConfig + + # Get the query length from the batch spec (they should all be uniform) + query_len = batch_spec.query_lens[0] + # Set num_speculative_tokens to query_len - 1 + # (since threshold is 1 + num_spec_tokens) + # Use ngram method which doesn't require a draft model + vllm_config.speculative_config = SpeculativeConfig( + method="ngram", num_speculative_tokens=query_len - 1 + ) + device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -397,11 +448,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # K_PE (rope component): [s_len, 1, qk_rope_head_dim] k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) - # Determine if this is decode or prefill + # Determine if this sequence uses the decode pipeline or prefill + # pipeline for each backend + # NOTE: For spec decode tests with uniform query_len > 1, backends that + # support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with + # uniform support) will use the decode pipeline (MQA-style), while + # backends that only support single-token queries will use the prefill + # pipeline (MHA-style). This ensures the reference implementation + # matches each backend's actual decode/prefill pipeline path. 
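+        # For example, with the "spec_decode_small" spec (query_lens=[4, 4, 4, 4]),
+        # FLASH_ATTN_MLA and FLASHMLA are expected to land in the decode branch
+        # below, while the remaining backends fall back to the prefill reference
+        # (and are skipped entirely when the backends are actually run).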
is_decode = [] - for i, backend in enumerate(BACKENDS_TO_TEST): - builder_cls, _ = get_attention_backend(backend) - is_decode.append(q_len <= builder_cls.reorder_batch_threshold) + for backend_idx, backend in enumerate(BACKENDS_TO_TEST): + builder_cls, _ = try_get_attention_backend(backend) + if is_spec_decode_test: + query_len_support = getattr( + builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY + ) + supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY + is_decode.append(supports_spec) + else: + threshold = getattr(builder_cls, "reorder_batch_threshold", None) + query_len_support = getattr( + builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY + ) + within_threshold = q_len <= threshold if threshold else False + if ( + within_threshold + and query_len_support == QueryLenSupport.UNIFORM + and i > 0 + ): + first_q_len = query_lens[0] + within_threshold = q_len == first_q_len + is_decode.append(within_threshold) # Split q into nope and rope components q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) @@ -480,11 +557,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) - for i, backend in enumerate(BACKENDS_TO_TEST): - if is_decode[i]: - all_sdpa_outputs[i].append(sdpa_out_i_decode) + for backend_idx, backend in enumerate(BACKENDS_TO_TEST): + if is_decode[backend_idx]: + all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode) else: - all_sdpa_outputs[i].append(sdpa_out_i_prefill) + all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill) # Inputs for vLLM MLA backends are just the new tokens all_q_vllm.append(q_c) @@ -499,9 +576,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_vllm = torch.cat(all_q_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) - sdpa_outputs = [] - for i, backend in enumerate(BACKENDS_TO_TEST): - sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0)) + sdpa_outputs = {} + for backend_idx, backend in enumerate(BACKENDS_TO_TEST): + sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0) # Create mock kv_b_proj using the same weights as reference implementation from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -518,7 +595,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): kv_b_proj_weight = kv_b_proj_weight.view( kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim) ) - mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False) # Create metadata using original batch spec common_attn_metadata = create_common_attn_metadata( @@ -539,7 +616,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ) # 4. 
Run vLLM backends and compare - for i, backend_name in enumerate(BACKENDS_TO_TEST): + for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST): + # Skip backends that don't support spec decode for spec decode tests + if is_spec_decode_test and backend_name not in spec_decode_backends: + continue + backend_output = run_attention_backend( backend_name, kv_cache_spec, @@ -558,14 +639,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): mock_kv_b_proj, ) + # Use backend_idx to get the correct SDPA output for this backend + expected_output = sdpa_outputs[backend_name] + # Check shape and dtype consistency - assert backend_output.shape == sdpa_outputs[i].shape, ( + assert backend_output.shape == expected_output.shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_outputs[i].shape}" + f"SDPA shape {expected_output.shape}" ) - assert backend_output.dtype == sdpa_outputs[i].dtype, ( + assert backend_output.dtype == expected_output.dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_outputs[i].dtype}" + f"SDPA dtype {expected_output.dtype}" ) assert torch.isfinite(backend_output).all(), ( @@ -576,12 +660,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-1 - max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item() + max_diff = torch.max(torch.abs(backend_output - expected_output)).item() max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i]) + torch.abs(backend_output - expected_output) / torch.abs(expected_output) ).item() all_close = torch.allclose( - backend_output, sdpa_outputs[i], rtol=rtol, atol=atol + backend_output, expected_output, rtol=rtol, atol=atol ) assert all_close, ( diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index feed66d33b58..15ed7bdc835b 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -3,26 +3,28 @@ """Utility functions for attention-related v1 tests.""" from dataclasses import dataclass -from typing import Optional, Union import pytest import torch -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.abstract import AttentionImpl +from vllm.attention.backends.registry import _Backend, backend_to_class_str from vllm.config import ( CacheConfig, CompilationConfig, DeviceConfig, LoadConfig, ModelConfig, - ModelDType, ParallelConfig, SchedulerConfig, VllmConfig, ) -from vllm.platforms import current_platform -from vllm.utils import resolve_obj_by_qualname -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.config.model import ModelDType +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -117,44 +119,17 @@ def create_common_attn_metadata( ) -def get_attention_backend(backend_name: _Backend): - """Set up attention backend classes for testing. - - Args: - backend_name: Name of the backend ("flash_attn", "flashinfer", etc.) 
- vllm_config: VllmConfig instance - - Returns: - Tuple of (backend_builder_class, backend_impl_class) - """ - backend_map = { - _Backend.FLASH_ATTN: ( - "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - if current_platform.is_cuda() - else "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" - ), - _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", - _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501 - _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501 - _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", - _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501 - _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501 - _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", - _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501 - _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501 - _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501 - } - - if backend_name not in backend_map: - raise ValueError(f"Unknown backend: {backend_name}") - - backend_class_name = backend_map[backend_name] - +def try_get_attention_backend( + backend: _Backend, +) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]: + """Try to get the attention backend class, skipping test if not found.""" + backend_class_str = backend_to_class_str(backend) try: - backend_class = resolve_obj_by_qualname(backend_class_name) + backend_class = resolve_obj_by_qualname(backend_class_str) return backend_class.get_builder_cls(), backend_class.get_impl_cls() except ImportError as e: - pytest.skip(f"{backend_name} not available: {e}") + pytest.skip(f"{backend_class_str} not available: {e}") + raise AssertionError("unreachable") from None def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec: @@ -174,14 +149,14 @@ def create_vllm_config( model_name: str = "meta-llama/Meta-Llama-3-8B", tensor_parallel_size: int = 1, max_model_len: int = 1024, - dtype: Union[ModelDType, torch.dtype] = "auto", + dtype: ModelDType | torch.dtype = "auto", num_gpu_blocks: int = 1000, block_size: int = 16, max_num_seqs: int = 256, max_num_batched_tokens: int = 8192, enable_chunked_prefill: bool = True, add_mock_model_methods: bool = True, - hf_config_override: Optional[dict] = None, + hf_config_override: dict | None = None, ) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" @@ -276,7 +251,7 @@ class BackendConfig: name: str env_vars: dict comp_config: dict # compilation config - specific_gpu_arch: Optional[tuple] = None + specific_gpu_arch: tuple | None = None # Define all backend configurations of full cudagraph to be tested diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index a31817ec72b6..df6a5f109874 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1,26 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib -from typing import Callable, Optional +from collections.abc import Callable +from typing import Any import pytest import torch import vllm.v1.core.kv_cache_utils as kv_cache_utils 
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig +from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import ( MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange, ) from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes, sha256, sha256_cbor +from vllm.utils.hashing import sha256, sha256_cbor +from vllm.utils.mem_constants import GiB_bytes from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_utils import ( BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, - PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, generate_scheduler_kv_cache_config, @@ -31,6 +33,7 @@ init_none_hash, is_kv_cache_spec_uniform, make_block_hash_with_group_id, + tensor_data, ) from vllm.v1.kv_cache_interface import ( FullAttentionSpec, @@ -42,7 +45,7 @@ SlidingWindowSpec, UniformTypeKVCacheSpecs, ) -from vllm.v1.metrics.stats import PrefixCacheStats +from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats from vllm.v1.request import Request pytestmark = pytest.mark.cpu_test @@ -60,12 +63,13 @@ def _auto_init_hash_fn(request): def make_request( request_id: str, - prompt_token_ids: list[int], + prompt_token_ids: list[int] | None, block_size: int = 3, hash_fn: Callable = hash, - mm_positions: Optional[list[PlaceholderRange]] = None, - mm_hashes: Optional[list[str]] = None, - cache_salt: Optional[str] = None, + mm_positions: list[PlaceholderRange] | None = None, + mm_hashes: list[str] | None = None, + cache_salt: str | None = None, + prompt_embeds: torch.Tensor | None = None, ): mm_features = [] if mm_positions is not None: @@ -89,6 +93,7 @@ def make_request( lora_request=None, cache_salt=cache_salt, block_hasher=get_request_block_hasher(block_size, hash_fn), + prompt_embeds=prompt_embeds, ) @@ -449,6 +454,70 @@ def test_generate_block_hash_extra_keys_cache_salt(): assert next_mm_idx == 1 +def test_generate_block_hash_extra_keys_prompt_embeds(): + prompt_embeds = torch.randn(10, 3) + request = make_request( + request_id="0", + prompt_token_ids=None, + mm_positions=None, + mm_hashes=None, + prompt_embeds=prompt_embeds, + ) + + # Test with prompt embeds for the first block + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0) + expected_embeds = prompt_embeds[0:5] + expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes() + assert extra_keys == (expected_bytes,) + + # Test with prompt embeds for the second block + extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0) + expected_embeds = prompt_embeds[5:10] + expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes() + assert extra_keys == (expected_bytes,) + + +def test_generate_block_hash_extra_keys_different_prompt_embeds(): + prompt_embeds1 = torch.randn(10, 3) + prompt_embeds2 = torch.randn(10, 3) + request1 = make_request( + request_id="0", + prompt_token_ids=None, + mm_positions=None, + mm_hashes=None, + prompt_embeds=prompt_embeds1, + ) + request2 = make_request( + request_id="1", + prompt_token_ids=None, + mm_positions=None, + mm_hashes=None, + prompt_embeds=prompt_embeds2, + ) + + extra_keys1, _ = generate_block_hash_extra_keys(request1, 0, 5, 0) + extra_keys2, _ = generate_block_hash_extra_keys(request2, 0, 5, 0) + assert extra_keys1 != extra_keys2 + + +def test_generate_block_hash_extra_keys_lora(): + request = make_request( + request_id="0", + prompt_token_ids=[_ for _ in range(6)], + ) + + request.lora_request = LoRARequest( + lora_name="test_lora_adapter", lora_int_id=1, 
lora_path="/path/to/lora" + ) + + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0) + assert extra_keys == ("test_lora_adapter",) + + request.lora_request = None + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0) + assert extra_keys is None + + @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_block_tokens(hash_fn): parent_block_hash = BlockHash(b"123") @@ -536,7 +605,7 @@ def test_metrics(): """ Test the prefix caching metrics. """ - metrics = PrefixCachingMetrics(max_recent_requests=5) + metrics = CachingMetrics(max_recent_requests=5) assert metrics.hit_rate == 0.0 metrics.observe(_stats(1, 20, 9)) @@ -568,7 +637,7 @@ def test_metrics_empty_stats(): """ Test the prefix caching metrics with empty stats. """ - metrics = PrefixCachingMetrics(max_recent_requests=5) + metrics = CachingMetrics(max_recent_requests=5) metrics.observe(_stats(0, 0, 0)) metrics.observe(_stats(1, 20, 9)) metrics.observe(_stats(0, 0, 0)) @@ -1537,3 +1606,88 @@ def test_merge_mla_spec(): ] with pytest.raises(AssertionError): kv_cache_specs[0].merge(kv_cache_specs) + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]): + block_size = 3 + num_tokens = 2 * block_size + prompt_token_ids = [_ for _ in range(num_tokens)] + hidden_size = 5 + prompt_embeds = torch.randn((num_tokens, hidden_size)) + + request = make_request( + request_id="0", + prompt_token_ids=prompt_token_ids, + block_size=block_size, + hash_fn=hash_fn, + prompt_embeds=prompt_embeds, + ) + + block_hashes = request.block_hashes + assert len(block_hashes) == 2 + + block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes() + expected_hash1 = hash_fn( + ( + kv_cache_utils.NONE_HASH, + tuple(prompt_token_ids[:block_size]), + (block1_embeds_bytes,), + ) + ) + assert block_hashes[0] == expected_hash1 + + block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes() + expected_hash2 = hash_fn( + ( + block_hashes[0], + tuple(prompt_token_ids[block_size:num_tokens]), + (block2_embeds_bytes,), + ) + ) + assert block_hashes[1] == expected_hash2 + + +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]): + block_size = 3 + num_tokens = 2 * block_size + prompt_token_ids = [_ for _ in range(num_tokens)] + hidden_size = 5 + prompt_embeds = torch.randn((num_tokens, hidden_size)) + + request = make_request( + request_id="0", + prompt_token_ids=prompt_token_ids, + block_size=block_size, + hash_fn=hash_fn, + mm_positions=[ + PlaceholderRange(offset=0, length=3), + PlaceholderRange(offset=3, length=3), + ], + mm_hashes=["hash1", "hash2"], + prompt_embeds=prompt_embeds, + ) + + block_hashes = request.block_hashes + assert len(block_hashes) == 2 + + block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes() + expected_hash1 = hash_fn( + ( + kv_cache_utils.NONE_HASH, + tuple(prompt_token_ids[:block_size]), + ("hash1", block1_embeds_bytes), + ) + ) + assert block_hashes[0] == expected_hash1 + + block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes() + expected_hash2 = hash_fn( + ( + block_hashes[0], + tuple(prompt_token_ids[block_size:num_tokens]), + ("hash2", block2_embeds_bytes), + ) + ) + assert block_hashes[1] == expected_hash2 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d08c1bcc57bd..837a513cb75e 100644 --- 
a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -3,7 +3,7 @@ """Compare the with and without prefix caching.""" import copy -from typing import Callable, Optional +from collections.abc import Callable import pytest import torch @@ -16,7 +16,7 @@ PlaceholderRange, ) from vllm.sampling_params import SamplingParams -from vllm.utils import sha256, sha256_cbor +from vllm.utils.hashing import sha256, sha256_cbor from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import ( @@ -55,10 +55,10 @@ def make_request( prompt_token_ids: list[int], block_size: int, hash_fn: Callable, - mm_positions: Optional[list[PlaceholderRange]] = None, - mm_hashes: Optional[list[str]] = None, - prompt_logprobs: Optional[int] = None, - cache_salt: Optional[str] = None, + mm_positions: list[PlaceholderRange] | None = None, + mm_hashes: list[str] | None = None, + prompt_logprobs: int | None = None, + cache_salt: str | None = None, ): mm_features = [] if mm_positions is not None: diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index e78cced2d2db..fba577239682 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from typing import Optional from unittest.mock import Mock import pytest @@ -31,7 +30,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager -from vllm.v1.structured_output.request import StructuredOutputRequest from .utils import EOS_TOKEN_ID, create_requests, create_scheduler @@ -78,9 +76,7 @@ def test_get_num_unfinished_requests(): (True, 5), ], ) -def test_schedule( - enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int] -): +def test_schedule(enable_prefix_caching: bool | None, prompt_logprobs: int | None): """Test scheduling. 
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs """ @@ -338,10 +334,10 @@ def test_stop_via_update_from_output(): requests[0].request_id: [], requests[1].request_id: [10], }, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -386,10 +382,10 @@ def test_stop_via_update_from_output(): requests[0].request_id: [10, 42], requests[1].request_id: [13], }, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -432,10 +428,10 @@ def test_stop_via_update_from_output(): requests[0].request_id: [10, 11], requests[1].request_id: [], }, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -473,10 +469,10 @@ def test_stop_via_update_from_output(): total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -497,6 +493,96 @@ def test_stop_via_update_from_output(): assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] +def test_check_stop_min_tokens(): + """Test that requests don't stop when min_tokens requirement isn't met.""" + from vllm.v1.core.sched.utils import check_stop + + # Test case 1: num_output_tokens < min_tokens + # Should return False (don't stop) + sampling_params = SamplingParams( + ignore_eos=False, + max_tokens=20, + min_tokens=5, + ) + request = Request( + request_id="0", + prompt_token_ids=[0, 1, 2], + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + ) + # Simulate having generated 3 output tokens (less than min_tokens=5) + request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present + + result = check_stop(request, max_model_len=100) + assert result is False, "Should not stop when num_output_tokens<min_tokens" + + # Test case 2: num_output_tokens >= min_tokens + # Should follow normal stopping logic (stop on EOS) + request.append_output_token_ids( + [ + 10, + 11, + 12, + 13, + 14, + EOS_TOKEN_ID, + ] + ) # 6 tokens > min_tokens + + result = check_stop(request, max_model_len=100) + assert result is True, "Should stop on EOS when min_tokens met" + assert request.status == RequestStatus.FINISHED_STOPPED + + # Test case 3: min_tokens = 0, should follow normal stopping logic + sampling_params_no_min = SamplingParams( + ignore_eos=False, + max_tokens=20, + min_tokens=0, + ) + request_no_min = Request( + request_id="1", + prompt_token_ids=[0, 1, 2], + sampling_params=sampling_params_no_min, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + ) + request_no_min.append_output_token_ids([10, EOS_TOKEN_ID]) + + result = check_stop(request_no_min, max_model_len=100) + assert result is True, "Should stop on EOS when min_tokens=0" + assert request_no_min.status == RequestStatus.FINISHED_STOPPED + + # Test case 4: min_tokens > 0 with stop token (not EOS) + sampling_params_stop = SamplingParams( + ignore_eos=False, + max_tokens=20, + min_tokens=5, + 
stop_token_ids=[42], + ) + request_stop = Request( + request_id="2", + prompt_token_ids=[0, 1, 2], + sampling_params=sampling_params_stop, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + ) + # Only 3 output tokens, less than min_tokens=5, but has stop token + request_stop.append_output_token_ids([10, 11, 42]) + result = check_stop(request_stop, max_model_len=100) + assert result is False, "Should not stop when num_output_tokens<min_tokens" + + # Test case 5: min_tokens met, should stop on stop token + request_stop.append_output_token_ids( + [10, 11, 12, 13, 14, 42] + ) # 6 tokens >= min_tokens=5 + + result = check_stop(request_stop, max_model_len=100) + assert result is True, "Should stop on stop token when min_tokens met" + assert request_stop.status == RequestStatus.FINISHED_STOPPED + assert request_stop.stop_reason == 42 + + @pytest.mark.parametrize( "enable_prefix_caching, prompt_logprobs", [ @@ -505,7 +591,7 @@ def test_stop_via_update_from_output(): ], ) def test_schedule_concurrent_batches( - enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int] + enable_prefix_caching: bool | None, prompt_logprobs: int | None ): scheduler = create_scheduler( max_num_batched_tokens=1024, @@ -717,8 +803,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): engine_core_outputs[0].scheduler_stats if engine_core_outputs else None ) if expected[0] == 0: + assert scheduler_stats is not None assert scheduler_stats.spec_decoding_stats is None else: + assert scheduler_stats is not None assert scheduler_stats.spec_decoding_stats is not None stats = scheduler_stats.spec_decoding_stats assert stats.num_drafts == expected[0] @@ -811,6 +899,7 @@ def test_kv_connector_basic(): scheduler = create_scheduler( enable_prefix_caching=True, use_kv_connector=True, + disable_hybrid_kv_cache_manager=True, ) NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks() BLOCK_SIZE = scheduler.cache_config.block_size @@ -926,6 +1015,67 @@ def test_kv_connector_basic(): ) +def test_external_prefix_cache_metrics(): + """ + Verify connector prefix cache metrics are updated + correctly when the scheduler processes requests with KV connector hits. + """ + + # Setup Scheduler. 
+ scheduler = create_scheduler( + enable_prefix_caching=False, + use_kv_connector=True, + disable_hybrid_kv_cache_manager=True, + ) + + # Mock connector to simulate a partial external cache hit + NUM_MATCHED_NEW_TOKENS = 4 + scheduler.connector.get_num_new_matched_tokens = Mock(name="method") + scheduler.connector.get_num_new_matched_tokens.return_value = ( + NUM_MATCHED_NEW_TOKENS, + False, + ) + + # --- Prepare simple requests --- + NUM_REQUESTS = 2 + NUM_TOKENS = 8 + MAX_TOKENS = 2 + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + ) + + for req in requests: + scheduler.add_request(req) + + # --- Trigger scheduling and simulate model output --- + output = scheduler.schedule() + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[r.request_id for r in requests], + req_id_to_index={r.request_id: i for i, r in enumerate(requests)}, + sampled_token_ids=[[1000]] * NUM_REQUESTS, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + + # Update scheduler stats + ecos = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) + + # --- Assertions --- + assert ecos is not None and len(ecos) > 0 + assert ecos[0].scheduler_stats is not None + + external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats + assert external_stats is not None + + assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS + assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS + assert external_stats.requests == NUM_REQUESTS + assert external_stats.preempted_requests == 0 + + def test_kv_connector_unable_to_allocate(): """ Test whether scheduler with KVConnector is able to handle @@ -940,6 +1090,7 @@ def test_kv_connector_unable_to_allocate(): use_kv_connector=True, block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, + disable_hybrid_kv_cache_manager=True, ) NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") @@ -1023,6 +1174,7 @@ def test_kv_connector_handles_preemption(): use_kv_connector=True, block_size=BLOCK_SIZE, num_blocks=NUM_BLOCKS, + disable_hybrid_kv_cache_manager=True, ) NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE @@ -1231,14 +1383,15 @@ def create_scheduler_with_priority( model: str = "facebook/opt-125m", max_num_seqs: int = 16, max_num_batched_tokens: int = 8192, - enable_prefix_caching: Optional[bool] = None, + enable_prefix_caching: bool | None = None, long_prefill_token_threshold: int = 0, disable_chunked_mm_input: bool = False, use_kv_connector: bool = False, num_blocks: int = 10000, block_size: int = 16, - max_model_len: Optional[int] = None, - num_speculative_tokens: Optional[int] = None, + max_model_len: int | None = None, + num_speculative_tokens: int | None = None, + disable_hybrid_kv_cache_manager: bool = False, ) -> Scheduler: """Create scheduler with priority policy enabled. 
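The connector mocking in test_external_prefix_cache_metrics above is plain unittest.mock. The snippet below is a self-contained sketch of the same pattern; FakeConnector, its method signature, and the 4-token hit are illustrative stand-ins rather than vLLM APIs.

from unittest.mock import Mock

class FakeConnector:
    """Stand-in for a KV connector exposing the method the scheduler queries."""

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        # Returns (number of externally cached tokens, load_async)
        return 0, False

connector = FakeConnector()
# Simulate a partial external prefix-cache hit of 4 tokens, as the test does.
connector.get_num_new_matched_tokens = Mock(return_value=(4, False))

hits, load_async = connector.get_num_new_matched_tokens(
    request=None, num_computed_tokens=0
)
assert (hits, load_async) == (4, False)
# The Mock records each call, so a test can also assert on how it was used.
assert connector.get_num_new_matched_tokens.call_count == 1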
@@ -1263,6 +1416,7 @@ def create_scheduler_with_priority( disable_chunked_mm_input=disable_chunked_mm_input, enable_chunked_prefill=True, policy="priority", # Enable priority scheduling + disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager, ) model_config = ModelConfig( model=model, @@ -1293,7 +1447,7 @@ def create_scheduler_with_priority( else None ) - speculative_config: Optional[SpeculativeConfig] = None + speculative_config: SpeculativeConfig | None = None if num_speculative_tokens is not None: speculative_config = SpeculativeConfig( model="ngram", num_speculative_tokens=num_speculative_tokens @@ -1321,18 +1475,19 @@ def create_scheduler_with_priority( kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), + block_size=block_size, ) def create_requests_with_priority( num_requests: int, priorities: list[int], - arrival_times: Optional[list[float]] = None, + arrival_times: list[float] | None = None, num_tokens: int = 10, - mm_positions: Optional[list[list[PlaceholderRange]]] = None, + mm_positions: list[list[PlaceholderRange]] | None = None, max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None, + stop_token_ids: list[int] | None = None, + prompt_logprobs: int | None = None, starting_idx: int = 0, ): """Create requests with specified priorities and arrival times.""" @@ -1851,7 +2006,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): sampling_params=sampling_params, pooling_params=None, eos_token_id=EOS_TOKEN_ID, - structured_output_request=StructuredOutputRequest(sampling_params), ) scheduler.add_request(request) output = scheduler.schedule() @@ -1860,7 +2014,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): assert len(scheduler.waiting) == 1 -def test_priority_scheduling_preemption_when_out_of_kv(): +def test_priority_scheduling_preemption_and_resumption_when_out_of_kv(): """Test that priority scheduling preempts lower priority requests when out of KV cache space.""" # Create scheduler with very limited memory to force preemption @@ -1869,6 +2023,8 @@ def test_priority_scheduling_preemption_when_out_of_kv(): max_num_batched_tokens=200, num_blocks=5, # Can hold 64 tokens (first block is null) block_size=16, # Standard block size + use_kv_connector=True, + disable_hybrid_kv_cache_manager=True, ) # Create a request and schedule it @@ -1880,12 +2036,13 @@ def test_priority_scheduling_preemption_when_out_of_kv(): starting_idx=0, )[0] scheduler.add_request(request_low) + # 1st schedule output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 1 assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 1 - # Simulate model execution + # Simulate model execution - 1st decode model_output = ModelRunnerOutput( req_ids=[request_low.request_id], req_id_to_index={request_low.request_id: 0}, @@ -1906,6 +2063,7 @@ def test_priority_scheduling_preemption_when_out_of_kv(): starting_idx=1, )[0] scheduler.add_request(request_high) + # 2nd schedule output = scheduler.schedule() # KV cache should be full at this point assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0 @@ -1914,7 +2072,7 @@ def test_priority_scheduling_preemption_when_out_of_kv(): assert len(scheduler.waiting) == 0 assert len(scheduler.running) == 2 - # Simulate model execution + # Simulate model execution - 2nd decode requests = [request_low, request_high] model_output = ModelRunnerOutput( req_ids=[req.request_id for req in 
requests], @@ -1927,7 +2085,7 @@ def test_priority_scheduling_preemption_when_out_of_kv(): ) scheduler.update_from_output(output, model_output) - # Schedule again - this should trigger preemption + # 3rd schedule - this should trigger preemption # req_low needs 32 tokens = 2 blocks # req_high needs 33 tokens = 3 blocks # so doesn't fit in 4 blocks. @@ -1937,9 +2095,44 @@ def test_priority_scheduling_preemption_when_out_of_kv(): assert len(output.scheduled_new_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 1 assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id + assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED assert len(scheduler.waiting) == 1 assert len(scheduler.running) == 1 + # Simulate model execution - 3rd decode + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[], [100]], + # spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + # Finish the requests to make room for the preempted requests to resume + scheduler.update_from_output(output, model_output) + scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED) + + # 4th Schedule - this should trigger the resumption + output = scheduler.schedule() + scheduled_cached_reqs = output.scheduled_cached_reqs + resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption + + assert len(output.scheduled_new_reqs) == 0 + assert scheduled_cached_reqs.num_reqs == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == 1 + + # Preempted request resumed in scheduled_cached_reqs + assert len(resumed_from_preemption) == 1 + assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1 + assert resumed_from_preemption[0] + assert scheduled_cached_reqs.req_ids[0] == request_low.request_id + assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None + # Resumed tokens include 30 prompt tokens and 2 decoded tokens + assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32 + assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100 + @pytest.mark.parametrize( ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"), diff --git a/tests/v1/core/test_scheduler_e2e.py b/tests/v1/core/test_scheduler_e2e.py index 90f8757ae493..f1df4e95d5f4 100644 --- a/tests/v1/core/test_scheduler_e2e.py +++ b/tests/v1/core/test_scheduler_e2e.py @@ -5,7 +5,7 @@ from vllm import LLM -MODEL = "meta-llama/Llama-3.2-1B" +MODEL = "hmellor/tiny-random-LlamaForCausalLM" PROMPT = "Hello my name is Robert and I" diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 75ef1a5ec165..3f5e1b9eeaf7 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch @@ -18,7 +17,7 @@ PlaceholderRange, ) from vllm.sampling_params import SamplingParams -from vllm.utils import sha256 +from vllm.utils.hashing import sha256 from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.scheduler import Scheduler @@ -37,17 +36,18 @@ def create_scheduler( model: str = "facebook/opt-125m", max_num_seqs: int = 16, max_num_batched_tokens: int = 8192, - enable_prefix_caching: Optional[bool] = None, + 
enable_prefix_caching: bool | None = None, long_prefill_token_threshold: int = 0, disable_chunked_mm_input: bool = False, use_kv_connector: bool = False, num_blocks: int = 10000, block_size: int = 16, - max_model_len: Optional[int] = None, - num_speculative_tokens: Optional[int] = None, + max_model_len: int | None = None, + num_speculative_tokens: int | None = None, skip_tokenizer_init: bool = False, async_scheduling: bool = False, -) -> Union[Scheduler, AsyncScheduler]: + disable_hybrid_kv_cache_manager: bool = False, +) -> Scheduler | AsyncScheduler: """Create scheduler under test. Args: @@ -71,6 +71,7 @@ def create_scheduler( disable_chunked_mm_input=disable_chunked_mm_input, enable_chunked_prefill=True, async_scheduling=async_scheduling, + disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager, ) model_config = ModelConfig( model=model, @@ -102,7 +103,7 @@ def create_scheduler( else None ) - speculative_config: Optional[SpeculativeConfig] = None + speculative_config: SpeculativeConfig | None = None if num_speculative_tokens is not None: speculative_config = SpeculativeConfig( model="ngram", num_speculative_tokens=num_speculative_tokens @@ -129,6 +130,7 @@ def create_scheduler( return scheduler_cls( vllm_config=vllm_config, kv_cache_config=kv_cache_config, + block_size=block_size, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), ) @@ -140,10 +142,10 @@ def create_scheduler( def create_requests( num_requests: int, num_tokens: int = 10, - mm_positions: Optional[list[list[PlaceholderRange]]] = None, + mm_positions: list[list[PlaceholderRange]] | None = None, max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None, + stop_token_ids: list[int] | None = None, + prompt_logprobs: int | None = None, same_prompt: bool = False, block_size: int = 16, ) -> list[Request]: diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 59841a446db3..bb953e5c70c8 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -11,7 +11,7 @@ from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, ParallelConfig, SchedulerConfig, @@ -34,15 +34,18 @@ def forward(self, x): def _create_vllm_config( - compilation_config: CompilationConfig, max_num_seqs: int = 8 + compilation_config: CompilationConfig, + max_num_seqs: int = 8, + lora_config: bool = False, ) -> MagicMock: mock_config = MagicMock(spec=VllmConfig) mock_config.compilation_config = compilation_config mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) mock_config.parallel_config = ParallelConfig() - + if not lora_config: + mock_config.lora_config = None # Mimic the behavior of VllmConfig.__post_init__() - if compilation_config.level == CompilationLevel.PIECEWISE: + if compilation_config.mode == CompilationMode.VLLM_COMPILE: compilation_config.set_splitting_ops_for_v1() return mock_config @@ -50,27 +53,39 @@ def _create_vllm_config( class TestCudagraphDispatcher: @pytest.mark.parametrize( - "case_id,cudagraph_mode_str,compilation_level", + "cudagraph_mode_str,compilation_mode,lora_config", [ # Test case 0: Full CG for mixed batches, no separate routine - (0, "FULL", CompilationLevel.NO_COMPILATION), + ("FULL", CompilationMode.NONE, False), # Test case 1: Full CG for uniform batches, piecewise for mixed - (1, "FULL_AND_PIECEWISE", 
CompilationLevel.NO_COMPILATION), + ("FULL_AND_PIECEWISE", CompilationMode.NONE, False), # Test case 2: Full CG for uniform batches, no CG for mixed - (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION), - # Test case 3: Piecewise for all - (3, "PIECEWISE", CompilationLevel.PIECEWISE), + ("FULL_DECODE_ONLY", CompilationMode.NONE, False), + # Test case 3: PIECEWISE for all + ("PIECEWISE", CompilationMode.VLLM_COMPILE, False), + # Test case 4: PIECEWISE for all, specialize LoRA cases + ("PIECEWISE", CompilationMode.VLLM_COMPILE, True), ], ) - def test_dispatcher(self, cudagraph_mode_str, compilation_level): + def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): # Setup dispatcher comp_config = CompilationConfig( cudagraph_mode=cudagraph_mode_str, - level=compilation_level, + mode=compilation_mode, cudagraph_capture_sizes=[1, 8], ) - config = _create_vllm_config(comp_config, max_num_seqs=8) + config = _create_vllm_config( + comp_config, max_num_seqs=8, lora_config=lora_config + ) + if ( + cudagraph_mode_str == "FULL_AND_PIECEWISE" + and compilation_mode == CompilationMode.NONE + ): + with pytest.raises(AssertionError): + dispatcher = CudagraphDispatcher(config) + return + dispatcher = CudagraphDispatcher(config) dispatcher.initialize_cudagraph_keys( cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1 @@ -78,17 +93,24 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_level): # Verify the key is initialized correctly if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: - assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2 + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == ( + 4 if lora_config else 2 + ) else: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0 if cudagraph_mode_str not in ["NONE", "PIECEWISE"]: - assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2 + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == ( + 4 if lora_config else 2 + ) else: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0 # Test dispatch logic # 1. 
non-uniform batch, size in cudagraph size list - desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) + desc_full_exact = BatchDescriptor( + num_tokens=8, + uniform_decode=False, + ) rt_mode, key = dispatcher.dispatch(desc_full_exact) if cudagraph_mode_str == "FULL": assert rt_mode == CUDAGraphMode.FULL @@ -138,7 +160,6 @@ def setup_method(self): self.persistent_input_buffer = torch.zeros(1, 10, device="cuda") self.input_tensor = torch.randn(1, 10, device="cuda") - @create_new_process_for_each_test("spawn") def test_capture_and_replay(self): wrapper = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL @@ -192,7 +213,6 @@ def test_capture_and_replay(self): eager_output = self.model(self.input_tensor) torch.testing.assert_close(eager_output, output2) - @create_new_process_for_each_test("spawn") def test_bypass_on_mode_mismatch(self): wrapper = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL @@ -216,7 +236,6 @@ def test_bypass_on_mode_mismatch(self): mock_forward.assert_called_once() assert not wrapper.concrete_cudagraph_entries - @create_new_process_for_each_test("spawn") def test_bypass_on_mode_none(self): wrapper = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL @@ -242,7 +261,7 @@ class TestCudagraphIntegration: def setup_method(self): # only FULL mode for non-uniform batches self.comp_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, + mode=CompilationMode.VLLM_COMPILE, cudagraph_mode="FULL", cudagraph_capture_sizes=[10, 20], ) diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index 8c8148ae2094..d6bde16eba36 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -10,7 +10,7 @@ from tests.utils import wait_for_gpu_memory_to_clear from tests.v1.attention.utils import full_cg_backend_configs as backend_configs from vllm import LLM -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, CompilationMode from vllm.platforms import current_platform @@ -73,7 +73,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte gpu_memory_utilization=0.45, max_model_len=1024, compilation_config=CompilationConfig( - level=3, cudagraph_mode=cudagraph_mode + mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode ), ) llm.generate(["Hello, my name is"] * 10) @@ -90,33 +90,28 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte ) -# test cudagraph_mode with different compilation level. -# (backend_name, cudagraph_mode, compilation_level, supported) +# test cudagraph_mode with different compilation mode. 
+# (backend_name, cudagraph_mode, compilation_mode, supported) combo_cases_2 = [ - ("FA2", "FULL", 0, True), # no compilation + full cudagraph - ("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph - ("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph - ("FA2", "PIECEWISE", 3, True), # piecewise compilation + piecewise cudagraph - ( - "FA2", - "FULL_AND_PIECEWISE", - 0, - False, - ), # piecewise cudagraph not supported without piecewise compilation - ("FA2", "FULL_AND_PIECEWISE", 3, True), - ("FA2", "FULL_DECODE_ONLY", 0, True), - ("FA2", "FULL_DECODE_ONLY", 3, True), - ("FA2", "NONE", 0, True), # no compilation + no cudagraph - ("FA2", "NONE", 3, True), # piecewise compilation + no cudagraph + ("FA2", "FULL", CompilationMode.NONE, True), + ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True), + ("FA2", "PIECEWISE", CompilationMode.NONE, False), + ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), + ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), + ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True), + ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), + ("FA2", "NONE", CompilationMode.NONE, True), + ("FA2", "NONE", CompilationMode.VLLM_COMPILE, True), ] @pytest.mark.parametrize( - "backend_name,cudagraph_mode,compilation_level,supported", combo_cases_2 + "backend_name,cudagraph_mode,compilation_mode,supported", combo_cases_2 ) -def test_cudagraph_compilation_combo(combo_case): - backend_name, cudagraph_mode, compilation_level, supported = combo_case - +def test_cudagraph_compilation_combo( + backend_name, cudagraph_mode, compilation_mode, supported +): env_vars = backend_configs[backend_name].env_vars with temporary_environ(env_vars), ExitStack() as stack: @@ -130,7 +125,7 @@ def test_cudagraph_compilation_combo(combo_case): gpu_memory_utilization=0.45, max_model_len=1024, compilation_config=CompilationConfig( - level=compilation_level, cudagraph_mode=cudagraph_mode + mode=compilation_mode, cudagraph_mode=cudagraph_mode ), ) llm.generate(["Hello, my name is"] * 10) diff --git a/tests/v1/distributed/test_async_llm_dp.py b/tests/v1/distributed/test_async_llm_dp.py index 75314dc37303..98d6ef7dbf44 100644 --- a/tests/v1/distributed/test_async_llm_dp.py +++ b/tests/v1/distributed/test_async_llm_dp.py @@ -5,7 +5,6 @@ import os from contextlib import ExitStack from dataclasses import dataclass -from typing import Optional import pytest @@ -17,7 +16,7 @@ from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import DPAsyncMPClient from vllm.v1.metrics.loggers import StatLoggerBase -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, SchedulerStats DP_SIZE = int(os.getenv("DP_SIZE", 2)) @@ -35,8 +34,8 @@ async def generate( prompt: PromptType, output_kind: RequestOutputKind, max_tokens: int, - prompt_logprobs: Optional[int] = None, - data_parallel_rank: Optional[int] = None, + prompt_logprobs: int | None = None, + data_parallel_rank: int | None = None, ) -> tuple[int, str]: # Ensure generate doesn't complete too fast for cancellation test. 
await asyncio.sleep(0.2) @@ -79,6 +78,9 @@ async def generate( async def test_load( output_kind: RequestOutputKind, data_parallel_backend: str, async_scheduling: bool ): + if async_scheduling and data_parallel_backend == "ray": + # TODO(NickLucche) Re-enable when async scheduling is supported + pytest.skip("Async scheduling is not supported with ray") stats_loggers = {} @dataclass @@ -91,8 +93,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): def record( self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, engine_idx: int = 0, ): if iteration_stats: diff --git a/tests/v1/distributed/test_internal_lb_dp.py b/tests/v1/distributed/test_internal_lb_dp.py index 452d3682e65d..8f7459e95ef6 100644 --- a/tests/v1/distributed/test_internal_lb_dp.py +++ b/tests/v1/distributed/test_internal_lb_dp.py @@ -5,7 +5,7 @@ import threading import time import traceback -from typing import Optional, cast +from typing import cast import openai # use the official client for correctness check import pytest @@ -46,7 +46,7 @@ def __init__( self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[Optional[tuple[RemoteOpenAIServer, list[str]]]] = [None] * ( + self.servers: list[tuple[RemoteOpenAIServer, list[str]] | None] = [None] * ( dp_size // dp_per_node ) self.server_threads: list[threading.Thread] = [] @@ -175,7 +175,7 @@ def __init__( self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[Optional[tuple[RemoteOpenAIServer, list[str]]]] = [None] * 2 + self.servers: list[tuple[RemoteOpenAIServer, list[str]] | None] = [None] * 2 self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: diff --git a/tests/v1/e2e/test_async_sched_and_preempt.py b/tests/v1/e2e/test_async_sched_and_preempt.py new file mode 100644 index 000000000000..7ad9606a66df --- /dev/null +++ b/tests/v1/e2e/test_async_sched_and_preempt.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import pytest +import torch._dynamo.config as dynamo_config + +from vllm import SamplingParams + +from ...conftest import VllmRunner +from ...models.utils import check_outputs_equal + +MODEL = "Qwen/Qwen3-0.6B" + + +@dynamo_config.patch(cache_size_limit=16) +def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch): + """Test consistency of combos of async scheduling, preemption, + uni/multiproc executor, and various sampling parameters.""" + + first_prompt = ( + "The following numbers of the sequence " + + ", ".join(str(i) for i in range(10)) + + " are:" + ) + example_prompts = [first_prompt, "In one word, the capital of France is "] + [ + f"Tell me about the number {i}: " for i in range(32) + ] + + sampling_param_tests: list[dict[str, Any]] = [ + dict(), + # dict(min_tokens=20), + dict(presence_penalty=-1.0), + dict(bad_words=["the", " the"]), + ] + + default_params = dict( + temperature=0.0, # greedy + max_tokens=20, + ) + + with monkeypatch.context() as m: + m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") + # m.setenv("VLLM_BATCH_INVARIANT", "1") + + outputs: list[tuple[str, list]] = [] + for test_preemption in [False, True]: + for executor in 
["mp", "uni"]: + for async_scheduling in [False, True]: + cache_arg: dict[str, Any] = ( + dict(num_gpu_blocks_override=32) + if test_preemption + else dict(gpu_memory_utilization=0.7) + ) + test_config = ( + f"executor={executor}, preemption={test_preemption}," + f" async_sched={async_scheduling}" + ) + print("-" * 80) + print(f"---- TESTING: {test_config}") + print("-" * 80) + with VllmRunner( + MODEL, + max_model_len=512, + enforce_eager=True, + async_scheduling=async_scheduling, + distributed_executor_backend=executor, + dtype="float32", # avoid precision errors + **cache_arg, + ) as vllm_model: + results = [] + for override_params in sampling_param_tests: + print(f"----------- RUNNING PARAMS: {override_params}") + results.append( + vllm_model.generate( + example_prompts, + sampling_params=SamplingParams( + **default_params, **override_params + ), + ) + ) + + if not outputs: + # First check that the different parameter configs + # actually result in different output. + for other_test, params in zip( + results[1:], sampling_param_tests[1:] + ): + with pytest.raises(AssertionError): + check_outputs_equal( + outputs_0_lst=results[0], + outputs_1_lst=other_test, + name_0=f"baseline params={params}", + name_1=f"other params={params}", + ) + + outputs.append((test_config, results)) + + baseline_config, baseline_tests = outputs[0] + + for test_config, test_outputs in outputs[1:]: + for base_outs, test_outs, params in zip( + baseline_tests, test_outputs, sampling_param_tests + ): + check_outputs_equal( + outputs_0_lst=base_outs, + outputs_1_lst=test_outs, + name_0=f"baseline=[{baseline_config}], params={params}", + name_1=f"config=[{test_config}], params={params}", + ) + + print(f"PASSED: config=[{test_config}], params={params}") diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index 89e5f26ac627..f2c6d1c1fd1a 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -7,7 +7,7 @@ import torch from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig, CompilationMode from vllm.distributed import cleanup_dist_env_and_memory from ...utils import fork_new_process_for_each_test @@ -75,9 +75,9 @@ def test_kv_sharing_fast_prefill( # This allows vLLM compilation backend to handle allocating and # managing buffers for cudagraph cudagraph_copy_inputs=True, - level=CompilationLevel.PIECEWISE + mode=CompilationMode.VLLM_COMPILE if not enforce_eager - else CompilationLevel.NO_COMPILATION, + else CompilationMode.NONE, ) with monkeypatch.context() as m: diff --git a/tests/v1/e2e/test_min_tokens.py b/tests/v1/e2e/test_min_tokens.py index e00a3d58debe..ec7ee0c3ebe6 100644 --- a/tests/v1/e2e/test_min_tokens.py +++ b/tests/v1/e2e/test_min_tokens.py @@ -13,8 +13,6 @@ 5) Multiple stop conditions """ -from typing import Optional, Union - import pytest from vllm import LLM, SamplingParams @@ -33,9 +31,9 @@ def __init__( name: str, min_tokens: int, max_tokens: int, - stop: Optional[Union[str, list[str]]] = None, - expected_min_len: Optional[int] = None, - expected_exact_len: Optional[int] = None, + stop: str | list[str] | None = None, + expected_min_len: int | None = None, + expected_exact_len: int | None = None, ): self.name = name self.min_tokens = min_tokens diff --git a/tests/v1/e2e/test_pooling_chunked_prefill.py b/tests/v1/e2e/test_pooling_chunked_prefill.py new file mode 100644 index 000000000000..a196e359920d --- 
/dev/null +++ b/tests/v1/e2e/test_pooling_chunked_prefill.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch.nn as nn + +from vllm.platforms import current_platform + +prompt = """ +Generals gathered in their masses +Just like witches at black masses +Evil minds that plot destruction +Sorcerer of death's construction +In the fields, the bodies burning +As the war machine keeps turning +Death and hatred to mankind +Poisoning their brainwashed minds +Oh, Lord, yeah + +Politicians hide themselves away +They only started the war +Why should they go out to fight? +They leave that all to the poor, yeah +Time will tell on their power minds +Making war just for fun +Treating people just like pawns in chess +Wait till their judgment day comes, yeah + +Now, in darkness, world stops turning +Ashes where their bodies burning +No more war pigs have the power +Hand of God has struck the hour +Day of Judgment, God is calling +On their knees, the war pigs crawling +Begging mercies for their sins +Satan, laughing, spreads his wings +Oh, Lord, yeah +""" + + +class WrapperPooler(nn.Module): + def __init__(self, pooler): + super().__init__() + self.pooler = pooler + self.chunks = [] + + def get_pooling_updates(self, task): + return self.pooler.get_pooling_updates(task) + + def forward( + self, + hidden_states, + pooling_metadata, + ): + self.chunks.append(hidden_states.shape[0]) + return self.pooler(hidden_states, pooling_metadata) + + +def inject_pooler(self): + model = self.get_model() + wrapper = WrapperPooler(model.pooler) + model.pooler = wrapper + + +def retrieve_chunks(self): + model = self.get_model() + chunks = model.pooler.chunks + model.pooler.chunks = [] + return chunks + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") +def test_pooling_chunked_prefill(vllm_runner, monkeypatch): + """Test chunked prefill for pooling models with LastPool.""" + + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + model_id = "Qwen/Qwen3-Embedding-0.6B" + + chunk_size = 10 + + # Set chunking parameters to force chunked prefill + # Note: Chunked prefill is automatically handled by vLLM + # internally based on the model size and prompt + with vllm_runner( + model_id, + runner="pooling", + long_prefill_token_threshold=chunk_size, + tensor_parallel_size=1, + enforce_eager=True, + enable_chunked_prefill=True, + ) as llm: + llm.get_llm().llm_engine.collective_rpc(inject_pooler) + + tokenizer = llm.get_llm().get_tokenizer() + tokens = tokenizer(prompt)["input_ids"] + prompt_len = len(tokens) + full_chunks, last_chunk = divmod(prompt_len, chunk_size) + expected_chunks = [chunk_size] * full_chunks + if last_chunk: + expected_chunks.append(last_chunk) + llm.embed([prompt]) + chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0] + + # Check that PoolerWrapper was called and chunks were received + assert len(chunks) > 1 + assert chunks == expected_chunks + + # Disable chunked prefill + with vllm_runner( + model_id, + runner="pooling", + tensor_parallel_size=1, + enforce_eager=True, + ) as llm: + llm.get_llm().llm_engine.collective_rpc(inject_pooler) + llm.embed([prompt]) + chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0] + + # Check that PoolerWrapper was called and no chunks were received + assert len(chunks) == 1 + assert chunks[0] == prompt_len + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA 
not available") +def test_pooling_prefix_cache(vllm_runner, monkeypatch): + """Test chunked prefill for pooling models with LastPool.""" + + verses = prompt.split("\n\n") + + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + model_id = "Qwen/Qwen3-Embedding-0.6B" + + with vllm_runner( + model_id, + runner="pooling", + enable_prefix_caching=True, + tensor_parallel_size=1, + enforce_eager=True, + ) as llm: + llm.get_llm().llm_engine.collective_rpc(inject_pooler) + tokenizer = llm.get_llm().get_tokenizer() + + prompt1 = "\n\n".join([verses[0], verses[1]]) + prompt2 = "\n\n".join([verses[0], verses[2]]) + tokens1 = tokenizer(prompt1)["input_ids"] + tokens2 = tokenizer(prompt2)["input_ids"] + prompt1_len = len(tokens1) + prompt2_len = len(tokens2) + + llm.embed([prompt1]) + chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0] + + assert len(chunks) == 1 + assert chunks[0] == prompt1_len + + llm.embed([prompt2]) + chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0] + + assert len(chunks) == 1 + assert chunks[0] <= prompt1_len + assert chunks[0] < prompt2_len + + cache_config = llm.get_llm().llm_engine.cache_config + print(f"{cache_config=}") + # Prefixes are cached in blocks + assert (prompt2_len - chunks[0]) % cache_config.block_size == 0 diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index fbbbd0389c26..7dbdf0ca0710 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -1,9 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import random -from typing import Any, Union +from typing import Any import pytest import torch @@ -34,7 +32,7 @@ def get_test_prompts(mm_enabled: bool): for kind in random_prompt_type_choices: word_choices = ["test", "temp", "hello", "where"] word = random.choice(word_choices) - prompt: Union[str, list[dict[str, Any]]] = "" + prompt: str | list[dict[str, Any]] = "" if kind == "repeat": prompt = f""" please repeat the word '{word}' 10 times. 
diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index c5c5d35b83c3..283a76dab672 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -6,6 +6,7 @@ from transformers import AutoTokenizer from tests.v1.engine.utils import ( + FULL_STRINGS, NUM_PROMPT_LOGPROBS_UNDER_TEST, NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN, @@ -18,8 +19,6 @@ from ...distributed.conftest import publisher_config, random_port # noqa: F401 -from tests.v1.engine.utils import FULL_STRINGS # isort: skip - EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]] EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor] diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 444d771a18d6..c9605ea1b07c 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -3,7 +3,6 @@ import asyncio from contextlib import ExitStack -from typing import Optional from unittest.mock import MagicMock import pytest @@ -16,9 +15,14 @@ from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.metrics.loggers import LoggingStatLogger +from vllm.v1.metrics.loggers import ( + AggregatedLoggingStatLogger, + LoggingStatLogger, + PerEngineStatLoggerAdapter, + PrometheusStatLogger, +) if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) @@ -53,8 +57,8 @@ async def generate( output_kind: RequestOutputKind, max_tokens: int, n: int = 1, - prompt_logprobs: Optional[int] = None, - cancel_after: Optional[int] = None, + prompt_logprobs: int | None = None, + cancel_after: int | None = None, ) -> tuple[int, str]: # Ensure generate doesn't complete too fast for cancellation test. await asyncio.sleep(0.2) @@ -385,6 +389,12 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.log = MagicMock() +class MockAggregatedStatLogger(AggregatedLoggingStatLogger): + def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]): + super().__init__(vllm_config, engine_indexes) + self.log = MagicMock() + + @pytest.mark.asyncio async def test_customize_loggers(monkeypatch): """Test that we can customize the loggers. @@ -402,10 +412,45 @@ async def test_customize_loggers(monkeypatch): await engine.do_log_stats() - stat_loggers = engine.logger_manager.per_engine_logger_dict - assert len(stat_loggers) == 1 - assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger - stat_loggers[0][0].log.assert_called_once() + stat_loggers = engine.logger_manager.stat_loggers + assert ( + len(stat_loggers) == 3 + ) # MockLoggingStatLogger + LoggingStatLogger + Prometheus Logger + print(f"{stat_loggers=}") + stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once() + assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter) + assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger) + assert isinstance(stat_loggers[2], PrometheusStatLogger) + + +@pytest.mark.asyncio +async def test_customize_aggregated_loggers(monkeypatch): + """Test that we can customize the aggregated loggers. + If a customized logger is provided at init, it should + be added to the default loggers. 
+ """ + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args( + TEXT_ENGINE_ARGS, + stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger], + ) + after.callback(engine.shutdown) + + await engine.do_log_stats() + + stat_loggers = engine.logger_manager.stat_loggers + assert len(stat_loggers) == 4 + # MockLoggingStatLogger + MockAggregatedStatLogger + # + LoggingStatLogger + PrometheusStatLogger + stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once() + stat_loggers[1].log.assert_called_once() + assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter) + assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger) + assert isinstance(stat_loggers[3], PrometheusStatLogger) @pytest.mark.asyncio(scope="module") @@ -545,9 +590,9 @@ async def collect_outputs( prompt: PromptType, sampling_params: SamplingParams, outputs_list: list[RequestOutput], -) -> Optional[RequestOutput]: +) -> RequestOutput | None: """Helper to collect outputs and return the final one.""" - final_output: Optional[RequestOutput] = None + final_output: RequestOutput | None = None async for output in engine.generate( request_id=request_id, prompt=prompt, sampling_params=sampling_params ): diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 997b2b74bb6b..becedb59f644 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -12,10 +12,11 @@ from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore -from vllm.v1.executor.abstract import Executor, UniProcExecutor +from vllm.v1.executor.abstract import Executor +from vllm.v1.executor.uniproc_executor import UniProcExecutor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput @@ -24,9 +25,11 @@ if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME) -PROMPT = "Hello my name is Robert and I love quantization kernels" +# test_engine_core_concurrent_batches assumes exactly 12 tokens per prompt. +# Adjust prompt if changing model to maintain 12-token length. 
+PROMPT = "I am Gyoubu Masataka Oniwa" PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index bc04d1f93f95..770560a5e549 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -8,7 +8,7 @@ import uuid from dataclasses import dataclass from threading import Thread -from typing import Any, Optional, Union +from typing import Any from unittest.mock import MagicMock import pytest @@ -21,7 +21,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient @@ -41,7 +41,7 @@ def make_request( - params: SamplingParams, prompt_tokens_ids: Optional[list[int]] = None + params: SamplingParams, prompt_tokens_ids: list[int] | None = None ) -> EngineCoreRequest: if not prompt_tokens_ids: prompt_tokens_ids = PROMPT_TOKENS @@ -113,9 +113,7 @@ async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict): # Dummy utility function to monkey-patch into engine core. -def echo( - self, msg: str, err_msg: Optional[str] = None, sleep: Optional[float] = None -) -> str: +def echo(self, msg: str, err_msg: str | None = None, sleep: float | None = None) -> str: print(f"echo util function called: {msg}, {err_msg}") if sleep is not None: time.sleep(sleep) @@ -317,7 +315,7 @@ def echo_dc( self, msg: str, return_list: bool = False, -) -> Union[MyDataclass, list[MyDataclass]]: +) -> MyDataclass | list[MyDataclass]: print(f"echo dc util function called: {msg}") val = None if msg is None else MyDataclass(msg) # Return dataclass to verify support for returning custom types @@ -330,7 +328,7 @@ def echo_dc_dict( self, msg: str, return_dict: bool = False, -) -> Union[MyDataclass, dict[str, MyDataclass]]: +) -> MyDataclass | dict[str, MyDataclass]: print(f"echo dc dict util function called: {msg}") val = None if msg is None else MyDataclass(msg) # Return dict of dataclasses to verify support for returning dicts diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 3f6f2211556f..c1d5f8af7917 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import random from typing import TYPE_CHECKING @@ -13,6 +11,8 @@ if TYPE_CHECKING: from tests.conftest import VllmRunner +else: + VllmRunner = object MODEL = "facebook/opt-125m" DTYPE = "half" diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 9ebf7f09503e..28ebe5166d96 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -3,7 +3,6 @@ import math import time -from typing import Optional import pytest @@ -118,13 +117,13 @@ def test_incremental_detokenization( def _validate_logprobs( gen_tokens: dict[str, list[int]], - gen_logprobs: dict[str, Optional[SampleLogprobs]], - gen_prompt_logprobs: dict[str, Optional[PromptLogprobs]], + gen_logprobs: dict[str, SampleLogprobs | None], + gen_prompt_logprobs: dict[str, PromptLogprobs | None], 
gen_cumulative_logprob: dict[str, float], dtv: DummyOutputProcessorTestVectors, request_id_list: list[str], - num_sample_logprobs: Optional[int], - num_prompt_logprobs: Optional[int], + num_sample_logprobs: int | None, + num_prompt_logprobs: int | None, ) -> None: for req_idx, req_id in enumerate(request_id_list): new_tokens = gen_tokens[req_id] @@ -413,8 +412,8 @@ def _validate_logprobs( @pytest.mark.parametrize("num_prompt_logprobs", [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) def test_logprobs_processor( request_output_kind: RequestOutputKind, - num_sample_logprobs: Optional[int], - num_prompt_logprobs: Optional[int], + num_sample_logprobs: int | None, + num_prompt_logprobs: int | None, dummy_test_vectors, ): output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) @@ -530,7 +529,7 @@ def test_logprobs_processor( ) def test_stop_token( include_stop_str_in_output: bool, - num_sample_logprobs: Optional[int], + num_sample_logprobs: int | None, stop_token_type: str, ignore_eos: bool, dummy_test_vectors, @@ -696,7 +695,7 @@ def test_stop_token( @pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) def test_stop_string( include_stop_str_in_output: bool, - num_sample_logprobs: Optional[int], + num_sample_logprobs: int | None, dummy_test_vectors, ): output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index 2f73756ff615..cb6865e42ef8 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -65,7 +65,7 @@ def __init__(self, gb: float): device_config=DeviceConfig(device="cpu"), ) - return Processor(vllm_config) + return Processor(vllm_config, tokenizer=None) def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 9b720f6eb668..23684a2c55ce 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -3,7 +3,7 @@ import random from dataclasses import dataclass -from typing import Optional, Union +from typing import TypeAlias import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -12,7 +12,7 @@ from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.outputs import LogprobsLists, LogprobsTensors -GeneralTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +GeneralTokenizerType: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast # Number of sample logprobs to request when testing sample logprobs NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5 @@ -332,16 +332,15 @@ def __init__( # For each request, for each sampled token offset, # a tuple of # (list of topk token ids, list of sample logprob vals, rank) - generated_logprobs_raw: Optional[ - list[list[tuple[list[int], list[float], int]]] - ] = None, + generated_logprobs_raw: list[list[tuple[list[int], list[float], int]]] + | None = None, # For each request, a tuple of # (prompt logprob val matrix, prompt logprob tok id matrix); # each matrix has dimensions # (num prompt toks) x (num prompt logprobs+1) - prompt_logprobs_raw: Optional[list[LogprobsTensors]] = None, - eos_token_id: Optional[int] = None, - stop_token_ids: Optional[list[int]] = None, + prompt_logprobs_raw: list[LogprobsTensors] | None = None, + eos_token_id: int | None = None, + stop_token_ids: list[int] | None = None, ignore_eos: bool = False, ) -> None: self.num_requests = 
len(tokens_list) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 16cdc19037ba..014e6eca2e02 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -2,8 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import json from dataclasses import fields from enum import Enum @@ -29,7 +27,9 @@ ) if TYPE_CHECKING: - from vllm.config import TokenizerMode + from vllm.config.model import TokenizerMode +else: + TokenizerMode = str NGRAM_SPEC_CONFIG = { "model": "[ngram]", @@ -864,3 +864,49 @@ def test_structured_output_batched_with_non_structured_outputs_requests( # non-structured outputs requests should not return a valid JSON here with pytest.raises(ValueError): output_json = json.loads(generated_text) + + +@pytest.mark.parametrize("guided_decoding_backend", ["xgrammar"]) +def test_structured_output_with_structural_tag( + monkeypatch: pytest.MonkeyPatch, + guided_decoding_backend: str, +): + monkeypatch.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="Qwen/Qwen2.5-1.5B-Instruct", + guided_decoding_backend=guided_decoding_backend, + ) + + structural_tag_config = { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "tags": [ + {"begin": "hello_flag", "content": {"type": "any_text"}, "end": "hello"} + ], + "triggers": ["hello"], + "stop_after_first": False, + }, + } + + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=500, + guided_decoding=StructuredOutputsParams( + structural_tag=json.dumps(structural_tag_config) + ), + ) + + prompt = "Hello and repeat hello 10 times, do not say anything else. 
Only say hello hello hello, now start" + outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True) + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + assert generated_text is not None + assert "hello_flag" in generated_text, ( + f"Expected 'hello_flag' to be in generated text, but got: {generated_text}" + ) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 35287f5b979a..736ccbefbc4d 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import openai # use the official client for correctness check import pytest import pytest_asyncio import regex as re -import requests from openai import BadRequestError from tests.utils import RemoteOpenAIServer @@ -195,7 +193,7 @@ async def test_too_many_completion_logprobs( [(MODEL_NAME, -1), (MODEL_NAME, 0), (MODEL_NAME, 1), (MODEL_NAME, None)], ) async def test_prompt_logprobs_completion( - client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int] + client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: int | None ): params: dict = { "prompt": ["A robot may not injure another robot", "My name is"], @@ -420,7 +418,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, model_name: assert chunk.usage is None else: assert chunk.usage is None - final_chunk = await stream.__anext__() + final_chunk = await anext(stream) assert final_chunk.usage is not None assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 @@ -450,7 +448,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, model_name: chunk.usage.prompt_tokens + chunk.usage.completion_tokens ) if chunk.choices[0].finish_reason is not None: - final_chunk = await stream.__anext__() + final_chunk = await anext(stream) assert final_chunk.usage is not None assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 @@ -687,17 +685,3 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): "structured_outputs": {"grammar": invalid_simplified_sql_grammar} }, ) - - -@pytest.mark.asyncio -async def test_completion_with_empty_prompt_embeds(client: openai.AsyncOpenAI) -> None: - """Test completion with empty prompt embeds.""" - payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []} - headers: dict[str, str] = {"Content-Type": "application/json"} - # base_url = http://localhost:8000/v1/completions - response = requests.post( - f"{client.base_url}completions", headers=headers, json=payload - ) - assert response.status_code == 200, ( - f"Expected status code 200, got {response.status_code}. 
" - ) diff --git a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py index 3c2b3de33958..276de2ff8e2c 100644 --- a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py +++ b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py @@ -32,6 +32,7 @@ def default_image_embeds_server_args() -> list[str]: "--enforce-eager", "--limit-mm-per-prompt", json.dumps({"image": MAXIMUM_IMAGES}), + "--enable-mm-embeds", ] diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py index 55328f0cf0f0..db52aef70f60 100644 --- a/tests/v1/entrypoints/openai/test_multi_api_servers.py +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -10,7 +10,7 @@ from tests.utils import RemoteOpenAIServer from tests.v1.utils import check_request_balancing -MODEL_NAME = "ibm-research/PowerMoE-3b" +MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" DP_SIZE = os.getenv("DP_SIZE", "1") diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py index c8bcd62d6680..7293ad09a717 100644 --- a/tests/v1/executor/test_executor.py +++ b/tests/v1/executor/test_executor.py @@ -3,7 +3,8 @@ import asyncio import os -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any import pytest @@ -20,12 +21,12 @@ class Mock: ... class CustomMultiprocExecutor(MultiprocExecutor): def collective_rpc( self, - method: Union[str, Callable], - timeout: Optional[float] = None, + method: str | Callable, + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict] = None, + kwargs: dict | None = None, non_block: bool = False, - unique_reply_rank: Optional[int] = None, + unique_reply_rank: int | None = None, ) -> list[Any]: # Drop marker to show that this was run with open(".marker", "w"): diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py index db1c757521f0..8e59b695ed57 100644 --- a/tests/v1/generation/test_batch_invariance.py +++ b/tests/v1/generation/test_batch_invariance.py @@ -3,56 +3,78 @@ import contextlib import os import random -import string import pytest import torch from vllm import LLM, SamplingParams +from vllm.platforms import current_platform + +skip_unsupported = pytest.mark.skipif( + not (current_platform.is_cuda() and current_platform.has_device_capability(90)), + reason="Requires CUDA and >= Hopper (SM90)", +) + + +@pytest.fixture(autouse=True) +def enable_batch_invariant_mode(): + """Automatically enable batch invariant kernel overrides for all tests.""" + old_value = os.environ.get("VLLM_BATCH_INVARIANT") + os.environ["VLLM_BATCH_INVARIANT"] = "1" + yield + # Restore original value after test + if old_value is None: + os.environ.pop("VLLM_BATCH_INVARIANT", None) + else: + os.environ["VLLM_BATCH_INVARIANT"] = old_value def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: - # Lightweight random prompt generator to vary prompt lengths and content. 
- vocab = [ - "alpha", - "bravo", - "charlie", - "delta", - "echo", - "foxtrot", - "golf", - "hotel", - "india", - "juliet", - "kilo", - "lima", - "mike", - "november", - "oscar", - "papa", - "quebec", - "romeo", - "sierra", - "tango", - "uniform", - "victor", - "whiskey", - "xray", - "yankee", - "zulu", + # Generate more realistic prompts that will actually produce varied tokens + # Use a mix of common English text patterns + + prompt_templates = [ + # Question-answer style + "Question: What is the capital of France?\nAnswer: The capital of France is", + "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which", + "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is", + # Story/narrative style + "Once upon a time in a distant galaxy, there lived", + "The old man walked slowly down the street, remembering", + "In the year 2157, humanity finally discovered", + # Technical/code style + "To implement a binary search tree in Python, first we need to", + "The algorithm works by iterating through the array and", + "Here's how to optimize database queries using indexing:", + # Factual/informative style + "The Renaissance was a period in European history that", + "Climate change is caused by several factors including", + "The human brain contains approximately 86 billion neurons which", + # Conversational style + "I've been thinking about getting a new laptop because", + "Yesterday I went to the store and bought", + "My favorite thing about summer is definitely", ] - n = random.randint(min_words, max_words) - words = random.choices(vocab, k=n) - # Add some noise and punctuation variability - if random.random() < 0.5: - words[0] = words[0].capitalize() - if random.random() < 0.2: - words.append("".join(random.choices(string.ascii_lowercase, k=5))) - punct = random.choice([".", "?", "!", "...", ""]) - return " ".join(words) + punct + # Pick a random template + base_prompt = random.choice(prompt_templates) + + if max_words < min_words: + max_words = min_words + target_words = random.randint(min_words, max_words) + + if target_words > 50: + # For longer prompts, repeat context + padding_text = ( + " This is an interesting topic that deserves more explanation. " + * (target_words // 50) + ) + base_prompt = base_prompt + padding_text + return base_prompt + +@skip_unsupported @pytest.mark.timeout(1000) def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): """ @@ -76,19 +98,21 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): seed. - Keep max_tokens and max_model_len bounded for speed and memory use. """ - random.seed(12345) + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) # Allow overrides from environment (useful for CI tuning) # "facebook/opt-125m" is too small, doesn't reliably test determinism model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5")) - batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64")) - assert batch_size >= 2, "Batch size should be >= 2 to mix needle." + max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128")) + min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024")) + max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048")) + assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle." # Keep GPU memory usage low to avoid startup allocation failures. 
- gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3")) - max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096")) - swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4")) + gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4")) + max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120")) # Sampling parameters: longer outputs with a more random-sounding # continuation,but still deterministic due to fixed seed. @@ -111,10 +135,9 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): # Engine with bs=1 behavior llm_bs1 = LLM_with_max_seqs( model=model, - max_num_seqs=1, + max_num_seqs=max_batch_size, gpu_memory_utilization=gpu_mem_util, max_model_len=max_model_len, - swap_space=swap_space_gb, ) # Baseline generation for the needle prompt alone. @@ -126,24 +149,24 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): # Engine with larger batch limit (e.g., 64) llm_bsN = LLM_with_max_seqs( model=model, - max_num_seqs=batch_size, + max_num_seqs=max_batch_size, gpu_memory_utilization=gpu_mem_util, max_model_len=max_model_len, - swap_space=swap_space_gb, ) mismatches = 0 for trial in range(num_trials): - # Create a batch of size `batch_size` and insert the needle at + # Create a batch of size `max_batch_size` and insert the needle at # a random index prompts: list[str] = [] + batch_size = random.randint(max_batch_size // 2, max_batch_size) needle_pos = random.randint(0, batch_size - 1) for i in range(batch_size): if i == needle_pos: prompts.append(needle_prompt) else: - prompts.append(_random_prompt()) + prompts.append(_random_prompt(min_random_prompt, max_random_prompt)) # Generate with the larger-batch engine outputs = llm_bsN.generate(prompts, sampling) @@ -154,19 +177,20 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): text = needle_output.outputs[0].text if text != baseline_text: + print(f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n") mismatches += 1 passes = num_trials - mismatches # Dump how many passed vs failed print( f"[determinism] total={num_trials}, passed={passes}, " - f"failed={mismatches}, batch_size={batch_size}" + f"failed={mismatches}, max_batch_size={max_batch_size}" ) if mismatches > 0: pytest.fail( f"Nondeterministic outputs detected: {mismatches} failed out " - f"of {num_trials} trials (batch_size={batch_size})." + f"of {num_trials} trials (max_batch_size={max_batch_size})." ) finally: @@ -190,85 +214,766 @@ def _extract_step_logprobs(request_output): ], dtype=torch.float32, ) - return t + return t, inner.token_ids - return None + return None, None -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="Requires CUDA to match production inference path.", -) -def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2(): - # model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m") +@skip_unsupported +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) +@pytest.mark.forked +def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): + backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) + os.environ["VLLM_ATTENTION_BACKEND"] = backend + + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) - # Force float32 to avoid precision-induced differences. 
+ # For batch invariance, disable custom all-reduce to ensure deterministic + # all-reduce operations (custom all-reduce may not be deterministic) + from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, + ) + + disable_custom_ar = vllm_is_batch_invariant() + + if disable_custom_ar: + print(f"\n{'=' * 80}") + print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})") + print(f"{'=' * 80}\n") + llm = LLM( model=model_name, tensor_parallel_size=tp_size, - enforce_eager=True, # helps reduce nondeterminism from some backends + enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", # not everything is supported ) - prompts = [ - "The capital of France is", - "The capital of Germany is", - ] + # Use more realistic prompts for better token generation + prompts = [_random_prompt(10, 50) for i in range(32)] sp = SamplingParams( - temperature=0.0, + temperature=0.6, top_p=1.0, max_tokens=8, - # Seed shouldn't matter at temperature=0, but keeping it stable anyway. seed=1234, logprobs=5, ) # BS=1: run prompts individually and collect logprobs per step. + print("\n" + "=" * 80) + print("STARTING BS=1 RUNS (each prompt individually)") + print("=" * 80 + "\n") + bs1_logprobs_per_prompt = [] - for p in prompts: + bs1_tokens_per_prompt = [] + for idx, p in enumerate(prompts): + print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...") outs = llm.generate([p], sp, use_tqdm=False) assert len(outs) == 1 - step_logprobs = _extract_step_logprobs(outs[0]) + step_logprobs, token_ids = _extract_step_logprobs(outs[0]) if step_logprobs is None: pytest.skip( "Logits are not available on RequestOutput; " "enable logprobs return to run this test." ) bs1_logprobs_per_prompt.append(step_logprobs) + bs1_tokens_per_prompt.append(token_ids) + print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}") - # BS=2: run prompts in a batch and collect logprobs per step for each + # BS=N: run prompts in a batch and collect logprobs per step for each # prompt. + print("\n" + "=" * 80) + print(f"STARTING BS={len(prompts)} RUN (all prompts batched)") + print("=" * 80 + "\n") + outs_batched = llm.generate(prompts, sp, use_tqdm=False) assert len(outs_batched) == len(prompts) - bs2_logprobs_per_prompt = [] - for o in outs_batched: - step_logprobs = _extract_step_logprobs(o) + bsN_logprobs_per_prompt = [] + bsN_tokens_per_prompt = [] + + print(f"\n[BS={len(prompts)}] Processing batched outputs...") + for idx, o in enumerate(outs_batched): + tokens = o.outputs[0].token_ids if o.outputs else "N/A" + print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}") + step_logprobs, token_ids = _extract_step_logprobs(o) if step_logprobs is None: pytest.skip( "Logits are not available on RequestOutput; " "enable logprobs return to run this test." ) - bs2_logprobs_per_prompt.append(step_logprobs) - - # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs. - for i, (logprobs_bs1, logprobs_bs2) in enumerate( - zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt) + bsN_logprobs_per_prompt.append(step_logprobs) + bsN_tokens_per_prompt.append(token_ids) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. 
+ failed_prompts = [] + for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate( + zip( + bs1_logprobs_per_prompt, + bsN_logprobs_per_prompt, + bs1_tokens_per_prompt, + bsN_tokens_per_prompt, + ) ): - assert len(logprobs_bs1) == len(logprobs_bs2), ( - f"Different number of generation steps for prompt index {i}: " - f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)" + if len(logprobs_bs1) != len(logprobs_bsN): + reason = ( + f"Different number of steps: {len(logprobs_bs1)} (BS=1) " + f"vs {len(logprobs_bsN)} (BS=N)" + ) + failed_prompts.append( + { + "prompt_idx": i, + "step": "all", + "reason": reason, + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + continue + + # Check if tokens match first + if tokens_bs1 != tokens_bsN: + failed_prompts.append( + { + "prompt_idx": i, + "step": "sampling", + "reason": "Different tokens sampled", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + "bs1_all_logprobs": [ + logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1)) + ], + "bsN_all_logprobs": [ + logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN)) + ], + } + ) + continue + + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): + if a.shape != b.shape: + failed_prompts.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Shape mismatch: {a.shape} vs {b.shape}", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + break + + if not torch.equal(a, b): + max_diff = torch.abs(a - b).max().item() + # Print which token failed + print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}") + bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A" + bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A" + print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}") + print(f" BS=1 logprob: {a.tolist()}") + print(f" BS=N logprob: {b.tolist()}") + failed_prompts.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + "bs1_all_logprobs": [ + logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1)) + ], + "bsN_all_logprobs": [ + logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN)) + ], + } + ) + break + + # Print summary of all failures + if failed_prompts: + print(f"\n{'=' * 80}") + fail_msg = ( + f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/" + f"{len(prompts)} prompts failed" + ) + print(fail_msg) + print(f"{'=' * 80}") + for fail in failed_prompts: + print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):") + print(f" Reason: {fail['reason']}") + print(f" Preview: {fail['prompt_preview']}...") + + # Always show the tokens + if "bs1_tokens" in fail: + print(f" BS=1 tokens: {fail['bs1_tokens']}") + if "bsN_tokens" in fail: + print(f" BS=N tokens: {fail['bsN_tokens']}") + + if "bs1_all_logprobs" in fail: + print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:") + for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]): + print(f" Step {step_idx}: {logprobs}") + print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:") + for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]): + print(f" Step {step_idx}: {logprobs}") + print(f"{'=' * 80}\n") + + # Fail the test with summary + msg = ( + f"Batch invariance violated in {len(failed_prompts)}/" + f"{len(prompts)} prompts. See output above for details." 
+ ) + pytest.fail(msg) + + +@skip_unsupported +def test_simple_generation(): + """ + Simple test that runs the model with a basic prompt and prints the output. + Useful for quick smoke testing and debugging. + """ + model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + + llm = LLM( + model=model, + max_num_seqs=1, + tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), + enforce_eager=True, + gpu_memory_utilization=0.9, + max_model_len=2048, + dtype="bfloat16", + enable_prefix_caching=False, + ) + + prompt = "the capital of france is" + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=20, + ) + + print(f"\n{'=' * 80}") + print("Running simple generation test") + print(f"Prompt: '{prompt}'") + print(f"{'=' * 80}\n") + + try: + outputs = llm.generate([prompt], sampling_params) + + assert len(outputs) == 1 + output_text = outputs[0].outputs[0].text + + print(f"Output: '{output_text}'") + print(f"\n{'=' * 80}") + print(f"Full completion: '{prompt}{output_text}'") + print(f"{'=' * 80}\n") + + finally: + with contextlib.suppress(Exception): + llm.shutdown() + + +@skip_unsupported +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) +@pytest.mark.forked +def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): + """ + This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN. + It DISABLES batch invariance mode and expects to see non-deterministic behavior + between BS=1 and BS=N runs. This demonstrates that batch invariance is actually + doing something useful. + + The test will PASS if we detect differences (proving batch invariance matters). + The test will FAIL if everything matches (suggesting batch invariance isn't needed). + """ + backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) + os.environ["VLLM_ATTENTION_BACKEND"] = backend + + # CRITICAL: Disable batch invariance for this test + old_value = os.environ.get("VLLM_BATCH_INVARIANT") + os.environ["VLLM_BATCH_INVARIANT"] = "0" + + try: + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + print(f"\n{'=' * 80}") + print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior") + print(f"{'=' * 80}\n") + + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", + ) + + # build ragged prompts to change shapes significantly across BS=1 vs BS=N + long_min = int(os.getenv("VLLM_MIN_PROMPT", "768")) + long_max = int(os.getenv("VLLM_MAX_PROMPT", "2048")) + prompts: list[str] = [] + options = [ + (max(long_min, 1536), max(long_max, 3072)), # very long + (max(1024, long_min), max(2048, long_max)), # long + (256, 512), # mid + (10, 20), # short + ] + + for _ in range(32): + lo, hi = random.choice(options) + prompts.append(_random_prompt(lo, hi)) + + sp = SamplingParams( + temperature=0.6, + top_p=1.0, + max_tokens=8, + seed=1234, + logprobs=5, ) - for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)): - assert a.shape == b.shape, ( - f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}" + + # BS=1: run prompts individually and collect logprobs per step. 
+ print("\n" + "=" * 80) + print("STARTING BS=1 RUNS (each prompt individually)") + print("=" * 80 + "\n") + + bs1_logprobs_per_prompt = [] + bs1_tokens_per_prompt = [] + for idx, p in enumerate(prompts): + print( + f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..." + ) + outs = llm.generate([p], sp, use_tqdm=False) + assert len(outs) == 1 + step_logprobs, token_ids = _extract_step_logprobs(outs[0]) + if step_logprobs is None: + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + bs1_logprobs_per_prompt.append(step_logprobs) + bs1_tokens_per_prompt.append(token_ids) + print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}") + + # BS=N: run prompts in a batch and collect logprobs per step for each prompt. + print("\n" + "=" * 80) + print(f"STARTING BS={len(prompts)} RUN (all prompts batched)") + print("=" * 80 + "\n") + + outs_batched = llm.generate(prompts, sp, use_tqdm=False) + assert len(outs_batched) == len(prompts) + bsN_logprobs_per_prompt = [] + bsN_tokens_per_prompt = [] + + print(f"\n[BS={len(prompts)}] Processing batched outputs...") + for idx, o in enumerate(outs_batched): + tokens = o.outputs[0].token_ids if o.outputs else "N/A" + print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}") + step_logprobs, token_ids = _extract_step_logprobs(o) + if step_logprobs is None: + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + bsN_logprobs_per_prompt.append(step_logprobs) + bsN_tokens_per_prompt.append(token_ids) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. + differences_found = [] + for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate( + zip( + bs1_logprobs_per_prompt, + bsN_logprobs_per_prompt, + bs1_tokens_per_prompt, + bsN_tokens_per_prompt, ) - # Bitwise exact equality. - assert torch.equal(a, b), ( - f"Bitwise logprobs mismatch at prompt {i}, step {t} " - f"(dtype={a.dtype}, shape={a.shape})." 
+ ): + if len(logprobs_bs1) != len(logprobs_bsN): + reason = ( + f"Different number of steps: {len(logprobs_bs1)} (BS=1) " + f"vs {len(logprobs_bsN)} (BS=N)" + ) + differences_found.append( + { + "prompt_idx": i, + "step": "all", + "reason": reason, + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + continue + + # Check if tokens match first + if tokens_bs1 != tokens_bsN: + differences_found.append( + { + "prompt_idx": i, + "step": "sampling", + "reason": "Different tokens sampled", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + continue + + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): + if a.shape != b.shape: + differences_found.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Shape mismatch: {a.shape} vs {b.shape}", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + break + + if not torch.equal(a, b): + max_diff = torch.abs(a - b).max().item() + print( + f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, " + f"Token {t}: max_diff={max_diff:.6e}" + ) + bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A" + bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A" + print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}") + print(f" BS=1 logprob: {a.tolist()}") + print(f" BS=N logprob: {b.tolist()}") + differences_found.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + break + + # Print summary + print(f"\n{'=' * 80}") + if differences_found: + success_msg = ( + f"✓ SUCCESS: Batch invariance is doing something! " + f"Found {len(differences_found)}/{len(prompts)} prompts " + f"with differences when batch invariance was DISABLED." ) + print(success_msg) + print(f"{'=' * 80}") + for diff in differences_found: + print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):") + print(f" Reason: {diff['reason']}") + print(f" Preview: {diff['prompt_preview']}...") + if "bs1_tokens" in diff: + print(f" BS=1 tokens: {diff['bs1_tokens']}") + if "bsN_tokens" in diff: + print(f" BS=N tokens: {diff['bsN_tokens']}") + print(f"{'=' * 80}\n") + # Test PASSES because we found differences (batch invariance matters!) + return + else: + # Test FAILS because everything matched even without batch invariance + fail_msg = ( + f"✗ UNEXPECTED: All {len(prompts)} prompts matched " + f"between BS=1 and BS=N even with batch invariance DISABLED. " + f"This suggests batch invariance might not be necessary, " + f"or the test needs more sensitive prompts." + ) + print(fail_msg) + print(f"{'=' * 80}\n") + pytest.fail(fail_msg) + + finally: + # Restore original value + if old_value is None: + os.environ.pop("VLLM_BATCH_INVARIANT", None) + else: + os.environ["VLLM_BATCH_INVARIANT"] = old_value + + +@skip_unsupported +@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) +@pytest.mark.forked +def test_decode_logprobs_match_prefill_logprobs(backend): + """ + Test that verifies decode logprobs match prefill logprobs. + + For each decoded token at position i: + 1. Run decode to generate N tokens and collect their logprobs + 2. 
For each position i in [0, N): + - Take prefix = prompt + tokens[0:i] + - Run prefill(prefix + tokens[i]) to get logprob of tokens[i] + - Verify prefill logprob matches decode logprob bitwise + + This ensures that the logprobs from decode are consistent with what + we would get if we ran prefill on each prefix. + """ + backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) + os.environ["VLLM_ATTENTION_BACKEND"] = backend + + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, + ) + + disable_custom_ar = vllm_is_batch_invariant() + + if disable_custom_ar: + print(f"\n{'=' * 80}") + print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})") + print(f"{'=' * 80}\n") + + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", + ) + + # Use a few test prompts + num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4")) + prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)] + + # Generate longer sequences to test multiple decode steps + max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16")) + + sp = SamplingParams( + temperature=0.0, # Greedy for determinism + max_tokens=max_tokens, + logprobs=5, + ) + + print("\n" + "=" * 80) + print("STEP 1: Running decode to generate tokens and collect logprobs") + print("=" * 80 + "\n") + + # Step 1: Run decode and collect logprobs + decode_outputs = llm.generate(prompts, sp, use_tqdm=False) + + failed_comparisons = [] + + for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)): + print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...") + + # Extract decode logprobs and tokens + decode_logprobs, token_ids = _extract_step_logprobs(decode_output) + if decode_logprobs is None: + pytest.skip( + "Logprobs are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + + print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}") + print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}") + + # Step 2: For each token position, run prefill and compare + print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...") + + for token_idx in range(len(token_ids)): + # Construct the prefix up to (but not including) this token + current_token = token_ids[token_idx] + + # We need to detokenize to get the text prefix + # For this, we'll use the tokenizer from the LLM + # However, the LLM API doesn't expose tokenizer easily, so we'll + # construct the prefix by decoding from the original prompt + + # Get text up to this point by using the output text + # This is approximate but should work for verification + if token_idx == 0: + prefix_prompt = prompt + else: + # Use the partial output text up to this token + # We'll need to construct this from the full output + prefix_output = decode_output.outputs[0] + # Get the text for tokens 0 to token_idx-1 + # Unfortunately, we don't have per-token text, so we'll use + # a different approach: run prefill with prompt + tokens[0:token_idx] + + # Actually, we need to get the actual text. 
Let's use a workaround: + # Run a generation with max_tokens = token_idx to get that prefix + prefix_sp = SamplingParams( + temperature=0.0, + max_tokens=token_idx, + logprobs=1, + ) + prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0] + prefix_prompt = prompt + prefix_output.outputs[0].text + + # Now run prefill with max_tokens=1 to get the logprob of the next token + prefill_sp = SamplingParams( + temperature=0.0, + max_tokens=1, + logprobs=5, + ) + + print( + f" [Token {token_idx}] Running prefill for prefix " + f"(len={len(prefix_prompt)})..." + ) + prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[ + 0 + ] + prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output) + + if prefill_logprobs is None: + print(f" [Token {token_idx}] Warning: No prefill logprobs available") + continue + + # The first token from prefill should match the current token + prefill_token = prefill_token_ids[0] + prefill_logprob = prefill_logprobs[0].item() + decode_logprob = decode_logprobs[token_idx].item() + + print( + f" [Token {token_idx}] Decode token: {current_token}, " + f"logprob: {decode_logprob:.8f}" + ) + print( + f" [Token {token_idx}] Prefill token: {prefill_token}, " + f"logprob: {prefill_logprob:.8f}" + ) + + # Check if tokens match + if current_token != prefill_token: + failed_comparisons.append( + { + "prompt_idx": prompt_idx, + "token_idx": token_idx, + "reason": "Token mismatch", + "decode_token": current_token, + "prefill_token": prefill_token, + "decode_logprob": decode_logprob, + "prefill_logprob": prefill_logprob, + "prompt_text": prompt[:100], + "prefix_text": prefix_prompt[:100], + } + ) + print(f" [Token {token_idx}] ✗ TOKEN MISMATCH!") + continue + + # Check if logprobs match bitwise + if decode_logprob != prefill_logprob: + diff = abs(decode_logprob - prefill_logprob) + failed_comparisons.append( + { + "prompt_idx": prompt_idx, + "token_idx": token_idx, + "reason": "Logprob mismatch", + "decode_token": current_token, + "prefill_token": prefill_token, + "decode_logprob": decode_logprob, + "prefill_logprob": prefill_logprob, + "diff": diff, + "prompt_text": prompt[:100], + "prefix_text": prefix_prompt[:100], + "decode_all_tokens": token_ids, + "decode_all_logprobs": decode_logprobs.tolist(), + } + ) + print(f" [Token {token_idx}] ✗ LOGPROB MISMATCH! 
diff={diff:.8e}") + else: + print(f" [Token {token_idx}] ✓ Match (bitwise equal)") + + # Print summary + print(f"\n{'=' * 80}") + if failed_comparisons: + print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected") + print(f"{'=' * 80}") + + # Group failures by prompt for better readability + failures_by_prompt: dict[int, list[dict]] = {} + for fail in failed_comparisons: + pid = fail["prompt_idx"] + if pid not in failures_by_prompt: + failures_by_prompt[pid] = [] + failures_by_prompt[pid].append(fail) + + for prompt_idx, failures in failures_by_prompt.items(): + print(f"\n{'=' * 80}") + print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...") + print(f"{'=' * 80}") + print(f"Total failures for this prompt: {len(failures)}") + + # Show where mismatches occur (which token positions) + mismatch_positions = [f["token_idx"] for f in failures] + print(f"Mismatch at token positions: {mismatch_positions}") + + # Show first few failures in detail + for i, fail in enumerate(failures[:5]): # Show first 5 failures per prompt + print(f"\n [Failure {i + 1}] Token position {fail['token_idx']}:") + print(f" Reason: {fail['reason']}") + print(f" Prefix text: '{fail['prefix_text']}...'") + print( + f" Decode: token={fail['decode_token']}, " + f"logprob={fail['decode_logprob']:.10f}" + ) + print( + f" Prefill: token={fail['prefill_token']}, " + f"logprob={fail['prefill_logprob']:.10f}" + ) + if "diff" in fail: + print(f" Difference: {fail['diff']:.10e}") + # Show in hex to see bitwise difference + import struct + + decode_hex = struct.pack("f", fail["decode_logprob"]).hex() + prefill_hex = struct.pack("f", fail["prefill_logprob"]).hex() + print(f" Decode logprob (hex): 0x{decode_hex}") + print(f" Prefill logprob (hex): 0x{prefill_hex}") + + # If we have all tokens/logprobs, show the context + if "decode_all_tokens" in fail and "decode_all_logprobs" in fail: + token_idx = fail["token_idx"] + all_tokens = fail["decode_all_tokens"] + all_logprobs = fail["decode_all_logprobs"] + + # Show context: 2 tokens before and after + start = max(0, token_idx - 2) + end = min(len(all_tokens), token_idx + 3) + + print(f" Context (tokens {start} to {end - 1}):") + for j in range(start, end): + marker = " <-- MISMATCH" if j == token_idx else "" + print( + f" [{j}] token={all_tokens[j]}, " + f"logprob={all_logprobs[j]:.8f}{marker}" + ) + + if len(failures) > 5: + print(f"\n ... and {len(failures) - 5} more failures for this prompt") + + print(f"\n{'=' * 80}\n") + + pytest.fail( + f"Decode logprobs do not match prefill logprobs: " + f"{len(failed_comparisons)} mismatches found." + ) + else: + print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!") + print(f"{'=' * 80}\n") def LLM_with_max_seqs( @@ -276,7 +981,6 @@ def LLM_with_max_seqs( max_num_seqs: int, gpu_memory_utilization: float, max_model_len: int, - swap_space: int, ) -> LLM: """ Helper to construct an LLM with a specific max_num_seqs (batch-size limit) @@ -285,15 +989,12 @@ def LLM_with_max_seqs( return LLM( model=model, max_num_seqs=max_num_seqs, - # Constrain GPU memory pool so test can run even on busy GPUs. gpu_memory_utilization=gpu_memory_utilization, - # Keep KV cache footprint small while allowing longer outputs. max_model_len=max_model_len, - # Allow some CPU offload if needed. - swap_space=swap_space, - # Keep things lean and CI-friendly. - dtype="float16", - # Single-GPU by default; override externally if desired. 
+ dtype="bfloat16", tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), - trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1", + enable_prefix_caching=False, + enforce_eager=True, + # Enable for MOE models + # enable_expert_parallel=True, ) diff --git a/tests/v1/generation/test_rms_norm_batch_invariant.py b/tests/v1/generation/test_rms_norm_batch_invariant.py new file mode 100644 index 000000000000..f79eba58d6ef --- /dev/null +++ b/tests/v1/generation/test_rms_norm_batch_invariant.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test batch-invariant RMS normalization against standard implementations. + +This test compares the Triton-based batch-invariant RMS norm implementation +with the standard CUDA-based implementation to ensure numerical accuracy. +""" + +import pytest +import torch + +from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform + +skip_unsupported = pytest.mark.skipif( + not (current_platform.is_cuda() and current_platform.has_device_capability(90)), + reason="Requires CUDA and >= Hopper (SM90)", +) + + +@skip_unsupported +@pytest.mark.parametrize("batch_size", [1, 4, 16, 64]) +@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("eps", [1e-6, 1e-5]) +def test_rms_norm_batch_invariant_vs_standard( + batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float +): + """ + Compare batch-invariant Triton RMS norm against standard CUDA implementation. + + Tests that the Triton-based batch-invariant RMS norm produces numerically + equivalent results to the standard CUDA implementation across various + configurations. + """ + device = torch.device("cuda") + + # Create test input and weight + torch.manual_seed(42) + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Standard implementation (CUDA ops) + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation (Triton) + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Compare outputs + # Use looser tolerance for bfloat16 due to its lower precision + if dtype == torch.bfloat16: + rtol, atol = 1e-1, 1e-1 # 10% relative tolerance for bfloat16 + else: + rtol, atol = 1e-2, 1e-2 # 1% for float16/float32 + + torch.testing.assert_close( + triton_output, + standard_output, + rtol=rtol, + atol=atol, + msg=f"RMS norm mismatch for batch_size={batch_size}, " + f"hidden_size={hidden_size}, " + f"dtype={dtype}, eps={eps}", + ) + + +@skip_unsupported +@pytest.mark.parametrize("batch_size", [1, 16, 128]) +@pytest.mark.parametrize("seq_len", [1, 32, 512]) +@pytest.mark.parametrize("hidden_size", [2048, 4096]) +def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int): + """ + Test RMS norm with 3D input tensors (batch, seq_len, hidden_size). + + Ensures that the batch-invariant RMS norm correctly handles multi-dimensional + inputs that are common in transformer models. 
+ """ + device = torch.device("cuda") + dtype = torch.bfloat16 + eps = 1e-6 + + torch.manual_seed(42) + input_tensor = torch.randn( + batch_size, seq_len, hidden_size, dtype=dtype, device=device + ) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Standard implementation + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Use looser tolerance for bfloat16 + rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16 + + torch.testing.assert_close( + triton_output, + standard_output, + rtol=rtol, + atol=atol, + msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, " + f"seq_len={seq_len}, hidden_size={hidden_size}", + ) + + +@skip_unsupported +def test_rms_norm_numerical_stability(): + """ + Test RMS norm numerical stability with extreme values. + + Ensures that both implementations handle edge cases like very small or large + values without producing NaN or Inf. + """ + device = torch.device("cuda") + dtype = torch.float16 + eps = 1e-6 + hidden_size = 2048 + + # Test cases with extreme values + test_cases = [ + # Very small values + torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5, + # Very large values + torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4, + # Mixed small and large + torch.randn(4, hidden_size, dtype=dtype, device=device) * 100, + # Values near zero + torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6, + ] + + weight = torch.ones(hidden_size, dtype=dtype, device=device) + + for idx, input_tensor in enumerate(test_cases): + # Standard implementation + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Check for NaN or Inf + assert not torch.isnan(standard_output).any(), ( + f"Standard RMS norm produced NaN for test case {idx}" + ) + assert not torch.isinf(standard_output).any(), ( + f"Standard RMS norm produced Inf for test case {idx}" + ) + assert not torch.isnan(triton_output).any(), ( + f"Triton RMS norm produced NaN for test case {idx}" + ) + assert not torch.isinf(triton_output).any(), ( + f"Triton RMS norm produced Inf for test case {idx}" + ) + + # Compare outputs - very lenient for extreme values with float16 + torch.testing.assert_close( + triton_output, + standard_output, + rtol=2e-1, # 20% tolerance for extreme values + atol=2e-1, + msg=f"RMS norm mismatch for extreme value test case {idx}", + ) + + +@skip_unsupported +def test_rms_norm_formula(): + """ + Test that RMS norm follows the correct mathematical formula. 
+ + Verifies: output = input / sqrt(mean(input^2) + eps) * weight + """ + device = torch.device("cuda") + dtype = torch.float32 # Use float32 for higher precision in formula check + eps = 1e-6 + hidden_size = 1024 + + torch.manual_seed(42) + input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Compute expected output using the formula + variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype) + expected_output = input_tensor * torch.rsqrt(variance + eps) * weight + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Compare against formula + torch.testing.assert_close( + triton_output, + expected_output, + rtol=1e-4, + atol=1e-4, + msg="Triton RMS norm doesn't match expected formula", + ) + + +@skip_unsupported +@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384]) +def test_rms_norm_different_hidden_sizes(hidden_size: int): + """ + Test RMS norm with various hidden sizes to ensure block size handling. + + The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it + correctly handles hidden sizes both smaller and larger than the block size. + """ + device = torch.device("cuda") + dtype = torch.bfloat16 + eps = 1e-6 + batch_size = 16 + + torch.manual_seed(42) + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Standard implementation + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Use looser tolerance for bfloat16 + rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16 + + torch.testing.assert_close( + triton_output, + standard_output, + rtol=rtol, + atol=atol, + msg=f"RMS norm mismatch for hidden_size={hidden_size}", + ) + + +@skip_unsupported +def test_rms_norm_determinism(): + """ + Test that batch-invariant RMS norm produces deterministic results. + + Runs the same input through the kernel multiple times and verifies + identical outputs. 
+ """ + device = torch.device("cuda") + dtype = torch.bfloat16 + eps = 1e-6 + hidden_size = 4096 + batch_size = 32 + + torch.manual_seed(42) + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Run multiple times + outputs = [] + for _ in range(5): + output = triton_rms_norm(input_tensor.clone(), weight, eps=eps) + outputs.append(output) + + # All outputs should be identical + reference = outputs[0] + for idx, output in enumerate(outputs[1:], start=1): + torch.testing.assert_close( + output, + reference, + rtol=0.0, + atol=0.0, + msg=f"RMS norm not deterministic: run {idx} differs from reference", + ) + + +if __name__ == "__main__": + # Run a quick smoke test + print("Running quick smoke test of RMS norm implementations...") + + device = torch.device("cuda") + batch_size = 8 + hidden_size = 4096 + dtype = torch.bfloat16 + eps = 1e-6 + + torch.manual_seed(42) + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Standard implementation + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Compare + max_diff = (triton_output - standard_output).abs().max().item() + mean_diff = (triton_output - standard_output).abs().mean().item() + + print(f"Max difference: {max_diff:.6e}") + print(f"Mean difference: {mean_diff:.6e}") + print(f"Standard output sample: {standard_output[0, :5].tolist()}") + print(f"Triton output sample: {triton_output[0, :5].tolist()}") + + if max_diff < 1e-3: + print("✓ Smoke test passed!") + else: + print("✗ Smoke test failed - differences too large") diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index 3b0f2d102c1f..a756858e2cc5 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -19,23 +19,36 @@ done echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" +DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD +if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then + KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"' +else + KV_CONFIG_HETERO_LAYOUT='' +fi + # Build the kv-transfer-config once if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then - KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}' + KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}' else - KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}" + KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}" fi # Models to run -MODELS=( - "Qwen/Qwen3-0.6B" -) +MODEL_NAMES=${MODEL_NAMES:-} +if [[ -n "$MODEL_NAMES" ]]; then + MODELS=("$MODEL_NAMES") +else + MODELS=( + "Qwen/Qwen3-0.6B" + ) +fi # Number of prefill and decode instances to create NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1 NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1} DECODER_TP_SIZE=${DECODER_TP_SIZE:-1} 
+GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2} # Find the git repository root directory GIT_ROOT=$(git rev-parse --show-toplevel) @@ -101,6 +114,12 @@ run_tests_for_model() { for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do # Calculate GPU ID - we'll distribute across available GPUs GPU_ID=$((i % $(get_num_gpus))) + NEXT_GPU=${GPU_ID} + # If PREFILLER_TP_SIZE is more than 1 + for (( j=1; j < PREFILLER_TP_SIZE; j++ )); do + NEXT_GPU=$(((GPU_ID + j) % $(get_num_gpus))) + GPU_ID="${GPU_ID},${NEXT_GPU}" + done # Calculate port number (base port + instance number) PORT=$((8100 + i)) @@ -111,12 +130,14 @@ run_tests_for_model() { # Build the command with or without model-specific args BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ + VLLM_KV_CACHE_LAYOUT='HND' \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ vllm serve $model_name \ --port $PORT \ --enforce-eager \ - --gpu-memory-utilization 0.2 \ + --disable-hybrid-kv-cache-manager \ + --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '$KV_CONFIG'" @@ -136,7 +157,12 @@ run_tests_for_model() { # Start decode instances for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs - GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(get_num_gpus))) + GPU_ID=$(((i + NEXT_GPU + 1) % $(get_num_gpus))) + # If DECODER_TP_SIZE is more than 1 + for (( j=1; j < DECODER_TP_SIZE; j++ )); do + NEXT_GPU=$(((GPU_ID + j) % $(get_num_gpus))) + GPU_ID="${GPU_ID},${NEXT_GPU}" + done # Calculate port number (base port + instance number) PORT=$((8200 + i)) # Calculate side channel port @@ -146,14 +172,24 @@ run_tests_for_model() { # Build the command with or without model-specific args BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ + VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \ UCX_NET_DEVICES=all \ VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ vllm serve $model_name \ --port $PORT \ --enforce-eager \ - --gpu-memory-utilization 0.2 \ - --tensor-parallel-size $DECODER_TP_SIZE \ + --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ + --disable-hybrid-kv-cache-manager \ --kv-transfer-config '$KV_CONFIG'" + + # DP-EP attention mode + if [[ -z "$DP_EP" ]]; then + BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE" + else + echo "DP-EP Attention enabled, deploying with dp=DECODER_TP_SIZE and tp=1" + BASE_CMD="${BASE_CMD} --data-parallel-size $DECODER_TP_SIZE \ + --tensor-parallel-size 1 --enable-expert-parallel" + fi if [ -n "$model_args" ]; then FULL_CMD="$BASE_CMD $model_args" @@ -180,7 +216,7 @@ run_tests_for_model() { done # Build the command for the proxy server with all the hosts and ports - PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" + PROXY_CMD="python3 ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192" # Add all prefill hosts and ports PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}" @@ -199,7 +235,7 @@ run_tests_for_model() { # Run lm eval for this model echo "Running tests for $model_name" - TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py + TEST_MODEL=$model_name python3 -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py # Clean up before running next model cleanup_instances diff --git a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh 
b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh index c48b452e24cd..a3eeedb2e514 100755 --- a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh @@ -85,6 +85,7 @@ run_tests_for_model() { --port $PREFILL_PORT \ --enforce-eager \ --gpu-memory-utilization 0.2 \ + --disable-hybrid-kv-cache-manager \ --kv-transfer-config '$KV_CONFIG'" if [ -n "$model_args" ]; then @@ -103,6 +104,7 @@ run_tests_for_model() { --port $DECODE_PORT \ --enforce-eager \ --gpu-memory-utilization 0.2 \ + --disable-hybrid-kv-cache-manager \ --kv-transfer-config '$KV_CONFIG'" if [ -n "$model_args" ]; then diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py index b301968e5bf8..a70f4caeb937 100644 --- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -12,7 +12,12 @@ RTOL = 0.03 # Model-specific expected values -EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59} +EXPECTED_VALUES = { + "Qwen/Qwen3-0.6B": 0.41, + "deepseek-ai/deepseek-vl2-small": 0.59, + "deepseek-ai/deepseek-vl2-tiny": 0.19, + "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65, +} SIMPLE_PROMPT = ( "The best part about working on vLLM is that I got to meet so many people across " diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 37d70510fe25..5768fcdb57ce 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -76,7 +76,8 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--host", type=str, default="localhost") + # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI + parser.add_argument("--host", type=str, default="127.0.0.1") # For prefiller instances parser.add_argument( diff --git a/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh new file mode 100755 index 000000000000..9308c81da063 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Utility to run integration tests sequentially with varying TP configurations. +SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh" + +# Define test configurations +configs=( + "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2" + "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2" + "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case + "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" + "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) +) + +run_tests() { + local label=$1 + local extra_env=$2 + + echo "=== Running tests (${label}) ===" + for cfg in "${configs[@]}"; do + echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}" + # Use 'env' to safely set variables without eval + if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then + echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}" + exit 1 + fi + done + echo "✅ All ${label} tests passed!" 
+} + +# Run tests +run_tests "default backend" "" + +# Check if FLASHINFER is set (non-empty) +if [[ -n "${FLASHINFER:-}" ]]; then + echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER" + run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER" +else + echo "FLASHINFER not set, skipping FLASHINFER runs." +fi diff --git a/tests/v1/kv_connector/unit/test_decode_bench_connector.py b/tests/v1/kv_connector/unit/test_decode_bench_connector.py new file mode 100644 index 000000000000..24802317a2bb --- /dev/null +++ b/tests/v1/kv_connector/unit/test_decode_bench_connector.py @@ -0,0 +1,415 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for DecodeBenchConnector. + +Tests the functionality of the DecodeBenchConnector which fills KV cache +with dummy values for decode performance benchmarking. +""" + +import pytest +import torch + +from vllm import SamplingParams +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole + +# ruff: noqa: E501 +from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import ( + DecodeBenchConnector, + DecodeBenchConnectorMetadata, +) +from vllm.forward_context import ForwardContext +from vllm.utils.hashing import sha256 +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import Request + +from .utils import ( + EOS_TOKEN_ID, + create_model_runner_output, + create_scheduler, + create_vllm_config, +) + + +class DecodeBenchTestRunner: + """Test runner for DecodeBenchConnector.""" + + def __init__(self, block_size: int, num_gpu_blocks: int): + self.block_size = block_size + self.num_gpu_blocks = num_gpu_blocks + + self.req_id = -1 + + # Create vllm config with DecodeBenchConnector + vllm_config = create_vllm_config( + block_size=block_size, max_num_batched_tokens=1000 + ) + vllm_config.kv_transfer_config = KVTransferConfig( + kv_connector="DecodeBenchConnector", + kv_role="kv_both", + ) + + self.vllm_config = vllm_config + self.scheduler: Scheduler = create_scheduler( + vllm_config, num_blocks=num_gpu_blocks + ) + + # Create worker-side connector + self.worker_connector = DecodeBenchConnector( + vllm_config, KVConnectorRole.WORKER + ) + + # Create dummy KV caches for testing + # Shape: [num_blocks, 2, num_heads, block_size, head_dim] + # Using simplified shape for testing + num_heads = 4 + head_dim = 64 + self.kv_caches = { + f"layer_{i}": torch.zeros( + num_gpu_blocks, 2, num_heads, block_size, head_dim + ) + for i in range(2) # 2 layers for testing + } + + # Register KV caches with worker connector + self.worker_connector.register_kv_caches(self.kv_caches) + + # Extract scheduler-side connector + scheduler_connector = self.scheduler.connector + assert scheduler_connector is not None + assert isinstance(scheduler_connector, DecodeBenchConnector) + self.scheduler_connector: DecodeBenchConnector = scheduler_connector + + init_none_hash(sha256) + self._block_hasher = get_request_block_hasher(block_size, sha256) + + self._dummy_ctx: ForwardContext = ForwardContext( + no_compile_layers={}, attn_metadata={}, virtual_engine=0 + ) + + def new_request(self, token_ids: list[int]) -> Request: + """Create a new request with given token IDs.""" + self.req_id += 1 + + req = Request( + request_id=str(self.req_id), + prompt_token_ids=token_ids, + sampling_params=SamplingParams(max_tokens=100), + 
pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + block_hasher=self._block_hasher, + ) + + self.scheduler.add_request(req) + return req + + def run_single_step(self, token_id: int = 0): + """Run a single scheduler + worker step.""" + scheduler_output = self.scheduler.schedule() + + # Get connector metadata + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + assert isinstance(kv_connector_metadata, DecodeBenchConnectorMetadata) + + # Bind metadata and load KV + self.worker_connector.bind_connector_metadata(kv_connector_metadata) + self.worker_connector.start_load_kv(self._dummy_ctx) + + if scheduler_output.total_num_scheduled_tokens > 0: + self.worker_connector.wait_for_save() + + self.worker_connector.clear_connector_metadata() + + # Create model runner output + model_runner_output = create_model_runner_output( + reqs=self.scheduler.running, + token_id=token_id, + ) + + self.scheduler.update_from_output(scheduler_output, model_runner_output) + + return scheduler_output, kv_connector_metadata + + +def test_decode_bench_connector_basic(): + """Test basic functionality of DecodeBenchConnector.""" + block_size = 16 + num_gpu_blocks = 100 + + runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + # Create a request with multiple blocks worth of tokens + num_tokens = block_size * 3 # 3 blocks + token_ids = [1] * num_tokens + + req = runner.new_request(token_ids) + + # Run first step - should fill KV cache with dummy values + scheduler_output, metadata = runner.run_single_step() + + # Check that get_num_new_matched_tokens returned correct value + # Should be num_tokens - 1 (all except the last token for decode) + expected_fill_tokens = num_tokens - 1 + + # Check metadata has the request to fill + assert len(metadata.reqs_to_fill) == 1 + assert req.request_id in metadata.reqs_to_fill + + block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id] + assert num_tokens_to_fill == expected_fill_tokens + + # For standard attention, there's only one group + assert len(block_ids_per_group) == 1 + block_ids = block_ids_per_group[0] + + # Calculate expected number of blocks + expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size + assert len(block_ids) == expected_num_blocks + + # Verify KV caches were filled with constant value + for layer_name, kv_cache in runner.kv_caches.items(): + for block_id in block_ids: + # Check that the block was filled + block_data = kv_cache[block_id] + # Should be filled with constant value 0.015 + assert torch.allclose(block_data, torch.tensor(0.015)) + + +def test_decode_bench_connector_no_refill(): + """Test that DecodeBenchConnector only fills once per request.""" + block_size = 16 + num_gpu_blocks = 100 + + runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + # Create a request + num_tokens = block_size * 2 + token_ids = [1] * num_tokens + + runner.new_request(token_ids) + + # Run first step - should fill KV cache + _, metadata1 = runner.run_single_step() + assert len(metadata1.reqs_to_fill) == 1 + + # Run second step - should NOT fill again (already filled) + _, metadata2 = runner.run_single_step() + assert len(metadata2.reqs_to_fill) == 0 + + +def test_decode_bench_connector_single_token(): + """Test DecodeBenchConnector with single token request.""" + block_size = 16 + num_gpu_blocks = 100 + + runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + # Create a request 
with just 1 token + # Should not fill anything (need at least 2 tokens: 1 to fill, 1 to decode) + token_ids = [1] + + runner.new_request(token_ids) + + # Run step - should NOT fill KV cache + _, metadata = runner.run_single_step() + assert len(metadata.reqs_to_fill) == 0 + + +def test_decode_bench_connector_two_tokens(): + """Test DecodeBenchConnector with two token request.""" + block_size = 16 + num_gpu_blocks = 100 + + runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + # Create a request with 2 tokens + # Should fill 1 token (first token), decode the second + token_ids = [1, 2] + + req = runner.new_request(token_ids) + + # Run step + _, metadata = runner.run_single_step() + + assert len(metadata.reqs_to_fill) == 1 + assert req.request_id in metadata.reqs_to_fill + + block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id] + assert num_tokens_to_fill == 1 + # For standard attention, there's only one group + assert len(block_ids_per_group) == 1 + assert len(block_ids_per_group[0]) == 1 # 1 token needs 1 block + + +def test_decode_bench_connector_large_context(): + """Test DecodeBenchConnector with large context size.""" + block_size = 16 + num_gpu_blocks = 1000 + + runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + # Create a request with many blocks + num_blocks = 20 + num_tokens = block_size * num_blocks + token_ids = list(range(num_tokens)) + + req = runner.new_request(token_ids) + + # Run step + _, metadata = runner.run_single_step() + + assert len(metadata.reqs_to_fill) == 1 + assert req.request_id in metadata.reqs_to_fill + + block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id] + + # Should fill all tokens except the last one + expected_fill_tokens = num_tokens - 1 + assert num_tokens_to_fill == expected_fill_tokens + + # For standard attention, there's only one group + assert len(block_ids_per_group) == 1 + block_ids = block_ids_per_group[0] + + # Calculate expected number of blocks + expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size + assert len(block_ids) == expected_num_blocks + + # Verify blocks were filled + for layer_name, kv_cache in runner.kv_caches.items(): + for block_id in block_ids: + block_data = kv_cache[block_id] + assert torch.allclose(block_data, torch.tensor(0.015)) + + +def test_decode_bench_connector_multiple_requests(): + """Test DecodeBenchConnector with multiple sequential requests.""" + block_size = 16 + num_gpu_blocks = 100 + + runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + # First request + req1 = runner.new_request([1] * (block_size * 2)) + _, metadata1 = runner.run_single_step() + + assert len(metadata1.reqs_to_fill) == 1 + assert req1.request_id in metadata1.reqs_to_fill + + # Complete first request + while runner.scheduler.running: + runner.run_single_step() + + # Add EOS to finish + scheduler_output = runner.scheduler.schedule() + model_runner_output = create_model_runner_output( + reqs=runner.scheduler.running, + token_id=EOS_TOKEN_ID, + use_eos=True, + ) + runner.scheduler.update_from_output(scheduler_output, model_runner_output) + + # Second request - should also get filled + req2 = runner.new_request([2] * (block_size * 3)) + _, metadata2 = runner.run_single_step() + + assert len(metadata2.reqs_to_fill) == 1 + assert req2.request_id in metadata2.reqs_to_fill + + # Different request should have different metadata + _, num_tokens1 = 
metadata1.reqs_to_fill[req1.request_id] + _, num_tokens2 = metadata2.reqs_to_fill[req2.request_id] + + assert num_tokens1 == block_size * 2 - 1 + assert num_tokens2 == block_size * 3 - 1 + + +def test_decode_bench_connector_partial_block(): + """Test DecodeBenchConnector with partial block filling.""" + block_size = 16 + num_gpu_blocks = 100 + + runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + # Create a request that doesn't align to block boundaries + # e.g., 2.5 blocks worth of tokens + num_tokens = block_size * 2 + block_size // 2 + token_ids = [1] * num_tokens + + req = runner.new_request(token_ids) + + # Run step + _, metadata = runner.run_single_step() + + assert len(metadata.reqs_to_fill) == 1 + assert req.request_id in metadata.reqs_to_fill + + block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id] + + # Should fill all tokens except the last one + expected_fill_tokens = num_tokens - 1 + assert num_tokens_to_fill == expected_fill_tokens + + # For standard attention, there's only one group + assert len(block_ids_per_group) == 1 + block_ids = block_ids_per_group[0] + + # Should allocate 3 blocks to hold the partial data + expected_num_blocks = 3 + assert len(block_ids) == expected_num_blocks + + +def test_decode_bench_connector_concurrent_requests(): + """Test DecodeBenchConnector with multiple concurrent requests in the same batch.""" + block_size = 16 + num_gpu_blocks = 1000 + + runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks) + + # Create multiple requests that will be batched together + req1 = runner.new_request([1] * (block_size * 2)) + req2 = runner.new_request([2] * (block_size * 3)) + req3 = runner.new_request([3] * (block_size * 1)) + + # Run first step - all requests should be filled concurrently + _, metadata = runner.run_single_step() + + # All three requests should be in the metadata + assert len(metadata.reqs_to_fill) == 3 + assert req1.request_id in metadata.reqs_to_fill + assert req2.request_id in metadata.reqs_to_fill + assert req3.request_id in metadata.reqs_to_fill + + # Verify each request has correct fill info + block_ids_per_group1, num_tokens1 = metadata.reqs_to_fill[req1.request_id] + block_ids_per_group2, num_tokens2 = metadata.reqs_to_fill[req2.request_id] + block_ids_per_group3, num_tokens3 = metadata.reqs_to_fill[req3.request_id] + + # Verify token counts (all tokens except last one) + assert num_tokens1 == block_size * 2 - 1 + assert num_tokens2 == block_size * 3 - 1 + assert num_tokens3 == block_size * 1 - 1 + + # Verify block counts for each request + assert len(block_ids_per_group1[0]) == 2 # 2 blocks + assert len(block_ids_per_group2[0]) == 3 # 3 blocks + assert len(block_ids_per_group3[0]) == 1 # 1 block + + # Verify all blocks are filled in KV cache + for req_id, (block_ids_per_group, _) in metadata.reqs_to_fill.items(): + block_ids = block_ids_per_group[0] + for layer_name, kv_cache in runner.kv_caches.items(): + for block_id in block_ids: + block_data = kv_cache[block_id] + assert torch.allclose(block_data, torch.tensor(0.015)) + + # Run second step - should NOT fill again (already filled) + _, metadata2 = runner.run_single_step() + assert len(metadata2.reqs_to_fill) == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py index 0bb67b574fa1..b5c8f378be18 100644 --- 
a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py +++ b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py @@ -26,7 +26,7 @@ def _make_empty_scheduler_output(): num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, kv_connector_metadata=SharedStorageConnectorMetadata(), ) diff --git a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py index 0902fbfe85f3..6b7b2226e758 100644 --- a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py +++ b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable +from collections.abc import Callable from unittest.mock import Mock import pytest diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 74ae3ca9a863..6748532afd97 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -4,9 +4,22 @@ import shutil import tempfile from pathlib import Path +from typing import Any + +import pytest from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( + MultiConnector, + MultiKVConnectorStats, +) +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + NixlKVConnectorStats, +) MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -19,6 +32,27 @@ SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) +# Test connector with custom stats for testing MultiConnector +class MockConnectorStats(KVConnectorStats): + """Mock stats class for testing.""" + + pass + + +class MockConnector(KVConnectorBase_V1): + """Mock connector that implements build_kv_connector_stats for testing.""" + + @classmethod + def build_kv_connector_stats( + cls, data: dict[str, Any] | None = None + ) -> KVConnectorStats | None: + return MockConnectorStats(data=data) if data is not None else None + + +# Register the mock connector +KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__) + + # Helper function to compare directories recursively def _compare_directories(dir1: Path, dir2: Path) -> bool: """Compares two directories recursively for identical content.""" @@ -80,6 +114,7 @@ def test_multi_shared_storage_connector_consistency(): enforce_eager=True, gpu_memory_utilization=0.5, kv_transfer_config=kv_transfer_config, + disable_hybrid_kv_cache_manager=True, ) # Run generation - this should trigger saving KV cache _ = llm.generate(PROMPTS, SAMPLING_PARAMS) @@ -225,3 +260,337 @@ def test_engine_id_conflict(): assert ids[0] != ids[1], ( f"Engine IDs should be different for different configs. 
Got {ids}" ) + + +class TestMultiConnectorStats: + """Tests for MultiConnector stats reconstruction and operations.""" + + def test_build_kv_connector_stats_with_none(self): + """Test that build_kv_connector_stats returns empty stats when given None.""" + stats = MultiConnector.build_kv_connector_stats(data=None) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + assert len(stats.data) == 0 + assert stats.is_empty() + + def test_build_kv_connector_stats_with_empty_dict(self): + """Test that build_kv_connector_stats returns empty stats with empty dict.""" + stats = MultiConnector.build_kv_connector_stats(data={}) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + assert len(stats.data) == 0 + assert stats.is_empty() + + def test_build_kv_connector_stats_reconstructs_nixl_stats(self): + """Test that NixlConnector stats are properly reconstructed with + correct data.""" + serialized_data = { + "NixlConnector": { + "data": { + "transfer_duration": [1.5, 2.3], + "post_duration": [0.1, 0.2], + "bytes_transferred": [1024, 2048], + "num_descriptors": [10, 20], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + } + } + + stats = MultiConnector.build_kv_connector_stats(data=serialized_data) + + assert "NixlConnector" in stats.data + nixl_stats = stats.data["NixlConnector"] + assert isinstance(nixl_stats, NixlKVConnectorStats) + assert nixl_stats.data["transfer_duration"] == [1.5, 2.3] + assert nixl_stats.data["post_duration"] == [0.1, 0.2] + assert nixl_stats.data["bytes_transferred"] == [1024, 2048] + assert nixl_stats.data["num_descriptors"] == [10, 20] + + def test_build_kv_connector_stats_with_multiple_connectors(self): + """Test reconstruction with multiple connector types that have custom stats.""" + serialized_data = { + "NixlConnector": { + "data": { + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + }, + "MockConnector": {"data": {"mock_field": [1, 2, 3]}}, + } + + stats = MultiConnector.build_kv_connector_stats(data=serialized_data) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + # Both connectors should be reconstructed + assert len(stats.data) == 2 + assert "NixlConnector" in stats.data + assert "MockConnector" in stats.data + assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats) + assert isinstance(stats.data["MockConnector"], MockConnectorStats) + # Verify data is preserved + assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]} + + def test_build_kv_connector_stats_raises_error_for_unknown_connector(self): + """Test that unknown connectors raise an error.""" + serialized_data = { + "UnknownConnector": {"data": {"some_field": [1, 2, 3]}}, + "NixlConnector": { + "data": { + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + }, + } + + with pytest.raises( + ValueError, match="Connector 'UnknownConnector' is not registered." 
+ ): + MultiConnector.build_kv_connector_stats(data=serialized_data) + + def test_build_kv_connector_stats_with_already_instantiated_objects(self): + """Test that already-instantiated stats objects are preserved (same process).""" + # This simulates the in-process case where stats are not serialized + nixl_stats = NixlKVConnectorStats( + data={ + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + mock_stats = MockConnectorStats(data={"mock_field": [1, 2, 3]}) + + data_with_objects = { + "NixlConnector": nixl_stats, + "MockConnector": mock_stats, + } + + stats = MultiConnector.build_kv_connector_stats(data=data_with_objects) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + assert len(stats.data) == 2 + # Verify objects are preserved as-is + assert stats.data["NixlConnector"] is nixl_stats + assert stats.data["MockConnector"] is mock_stats + + def test_build_kv_connector_stats_with_mixed_objects_and_dicts(self): + """Test handling mixed already-instantiated and serialized stats.""" + # This can happen during transition or partial serialization + nixl_stats = NixlKVConnectorStats( + data={ + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + + mixed_data = { + "NixlConnector": nixl_stats, # Already instantiated + "MockConnector": {"data": {"mock_field": [1, 2, 3]}}, # Serialized + } + + stats = MultiConnector.build_kv_connector_stats(data=mixed_data) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + assert len(stats.data) == 2 + # Instantiated object preserved + assert stats.data["NixlConnector"] is nixl_stats + # Serialized object reconstructed + assert isinstance(stats.data["MockConnector"], MockConnectorStats) + assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]} + + def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self): + """Test that connectors without custom stats (return None) are skipped.""" + # SharedStorageConnector doesn't override build_kv_connector_stats, + # so it returns None and should be skipped + serialized_data = { + "NixlConnector": { + "data": { + "transfer_duration": [1.5], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + }, + "SharedStorageConnector": {"data": {"some_field": [1, 2, 3]}}, + } + + stats = MultiConnector.build_kv_connector_stats(data=serialized_data) + + assert stats is not None + assert isinstance(stats, MultiKVConnectorStats) + # Only NixlConnector should be reconstructed + assert len(stats.data) == 1 + assert "NixlConnector" in stats.data + assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats) + # SharedStorageConnector should be skipped (returns None) + assert "SharedStorageConnector" not in stats.data + + def test_build_kv_connector_stats_handles_malformed_data(self): + """Test that malformed data raises appropriate errors.""" + serialized_data = { + "NixlConnector": {"wrong_field": {"transfer_duration": [1.5]}} + } + + with pytest.raises(AssertionError, match="Expected a dict with a 'data' field"): + MultiConnector.build_kv_connector_stats(data=serialized_data) + + def test_aggregate_same_connector(self): + """Test aggregating stats from the same connector type.""" 
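# Illustrative sketch (not vLLM's implementation): the assertions in this test
# pin down the expected aggregation semantics - for a given connector, each
# per-field sample list is extended with the samples from the other stats
# object. A standalone toy version of that behavior is shown here;
# merge_samples is a hypothetical helper name.

def merge_samples(
    a: dict[str, list[float]], b: dict[str, list[float]]
) -> dict[str, list[float]]:
    """Extend each sample list in `a` with the corresponding list from `b`."""
    for field, samples in b.items():
        a.setdefault(field, []).extend(samples)
    return a


assert merge_samples({"transfer_duration": [1.0]}, {"transfer_duration": [2.0]}) == {
    "transfer_duration": [1.0, 2.0]
}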
+ stats1 = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [1.0], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + stats2 = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [2.0], + "post_duration": [0.2], + "bytes_transferred": [2048], + "num_descriptors": [20], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + result = stats1.aggregate(stats2) + + assert result is stats1 # Should return self + assert "NixlConnector" in result.data + nixl_stats = result.data["NixlConnector"] + assert nixl_stats.data["transfer_duration"] == [1.0, 2.0] + assert nixl_stats.data["post_duration"] == [0.1, 0.2] + assert nixl_stats.data["bytes_transferred"] == [1024, 2048] + assert nixl_stats.data["num_descriptors"] == [10, 20] + + def test_aggregate_new_connector(self): + """Test aggregating stats when a new connector type appears.""" + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorStats, + ) + + stats1 = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [1.0], + "post_duration": [0.1], + "bytes_transferred": [1024], + "num_descriptors": [10], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + stats2 = MultiKVConnectorStats( + data={"SharedStorageConnector": KVConnectorStats(data={"field": [1, 2]})} + ) + + result = stats1.aggregate(stats2) + + assert "NixlConnector" in result.data + assert "SharedStorageConnector" in result.data + + def test_reduce(self): + """Test that reduce() correctly reduces all nested connector stats.""" + stats = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [1.0, 2.0], + "post_duration": [0.1, 0.2], + "bytes_transferred": [1024, 2048], + "num_descriptors": [10, 20], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + reduced = stats.reduce() + + assert "NixlConnector" in reduced + assert isinstance(reduced["NixlConnector"], dict) + # Check that the stats were reduced (should have aggregated values) + assert "Num successful transfers" in reduced["NixlConnector"] + assert reduced["NixlConnector"]["Num successful transfers"] == 2 + + def test_reset(self): + """Test that reset() resets all nested connector stats.""" + stats = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats( + data={ + "transfer_duration": [1.0, 2.0], + "post_duration": [0.1, 0.2], + "bytes_transferred": [1024, 2048], + "num_descriptors": [10, 20], + "num_failed_transfers": [], + "num_failed_notifications": [], + } + ) + } + ) + + assert not stats.is_empty() + + stats.reset() + + # After reset, stats should be empty + assert stats.is_empty() + nixl_stats = stats.data["NixlConnector"] + assert len(nixl_stats.data["transfer_duration"]) == 0 + + def test_is_empty_with_multiple_connectors(self): + """Test is_empty() returns correct value with multiple connectors.""" + # All empty + stats = MultiKVConnectorStats( + data={ + "NixlConnector": NixlKVConnectorStats(data={}), + } + ) + # Initialize empty stats + stats.data["NixlConnector"].reset() + assert stats.is_empty() + + # One non-empty + stats.data["NixlConnector"].data["transfer_duration"].append(1.0) + assert not stats.is_empty() diff --git 
a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index a1f53cb25563..445d115010cd 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -9,7 +9,6 @@ import time import uuid from collections import defaultdict -from typing import Optional from unittest.mock import patch import pytest @@ -154,7 +153,7 @@ def make_prepped_xfer( local_block_descs_ids: list[int], remote_xfer_side_handle: int, remote_block_descs_ids: list[int], - notif_msg: Optional[bytes] = None, + notif_msg: bytes | None = None, ) -> int: return uuid.uuid4().int @@ -191,7 +190,6 @@ def _make_fake_nixl_pkg(): # Copy of FakeNixlWrapper implementation for Ray workers import uuid from collections import defaultdict -from typing import Optional {fake_nixl_source} @@ -288,9 +286,12 @@ def test_prompt_less_than_block_size(): class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" - def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): + def __init__( + self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs + ): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency + self.kv_cache_layout = kv_cache_layout def _nixl_handshake( self, host: str, port: int, remote_tp_size: int, expected_engine_id: str @@ -564,12 +565,63 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): kv_cache_layout=mismatched_layout, ) - # We don't check layout for homogeneous TP and MLA for now, as the - # whole block is moved. - worker.add_remote_agent(meta, remote_tp_size=2) + with pytest.raises(RuntimeError): + # mismatched layout is expected to fail + worker.add_remote_agent(meta, remote_tp_size=2) with pytest.raises(AssertionError): worker.add_remote_agent(meta, remote_tp_size=1) + @patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, + ) + def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental( + self, dist_init + ): + """ + Verify that adding a remote agent with a mismatched kv_cache_layout + succeeds when enable_permute_local_kv is set (heterogeneous TP only). 
+ """ + vllm_config = create_vllm_config(enable_permute_local_kv=True) + + # Mock TP world size to 2 to force heterogeneous TP when + # remote_tp_size=1 + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 + return_value=2, + ): + # Initialize connector and worker (with fake NIXL wrapper) + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, + connector.engine_id, + hand_shake_latency=0, + kv_cache_layout="NHD", + ) + worker = connector.connector_worker + + # Minimal local registration params used by add_remote_agent + worker.slot_size_per_layer = [2048] + worker.block_len_per_layer = [2048 * worker.block_size] + worker.num_blocks = 1 + worker.dst_num_blocks[worker.engine_id] = worker.num_blocks + + # Metadata with different kv_cache_layout than local worker + meta = NixlAgentMetadata( + engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + agent_metadata=FakeNixlWrapper.AGENT_METADATA, + kv_caches_base_addr=[0], + num_blocks=1, + # prefill TP=1, decode TP=2, remote block_lens is double to local + block_lens=[i * 2 for i in worker.block_len_per_layer], + attn_backend_name=worker.backend_name, + kv_cache_layout="HND", + ) + + # We don't check layout for homogeneous TP and MLA for now, as the + # whole block is moved. + worker.add_remote_agent(meta, remote_tp_size=1) + # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which # we put here is important. First run ray, it will clean up the resources, then @@ -651,7 +703,7 @@ def test_kv_connector_stats_aggregation(): # Create KVOutputAggregator for 3 workers (simulating TP=3), same thing # done in MultiprocExecutor.execute_model - aggregator = KVOutputAggregator(world_size=3) + aggregator = KVOutputAggregator(expected_finished_count=3) # Create stats for multiple workers with different transfer patterns worker1_stats = NixlKVConnectorStats() @@ -716,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation(): KVOutputAggregator (used by MultiprocExecutor). 
""" - aggregator = KVOutputAggregator(world_size=3) + aggregator = KVOutputAggregator(expected_finished_count=3) from dataclasses import dataclass @@ -785,6 +837,75 @@ def make_multi_stats(nixl_count: int, foo_count: int) -> MultiKVConnectorStats: assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6 +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_scheduler_kv_connector_stats_aggregation(): + """Test scheduler and worker KV connector stats aggregation.""" + from vllm.v1.core.sched.output import SchedulerOutput + + scheduler = create_scheduler(create_vllm_config()) + + # Worker stats with transfer metrics + worker_stats = NixlKVConnectorStats() + worker_stats.record_transfer(get_default_xfer_telemetry()) + worker_stats.data["remote_tokens"] = [] + + # Scheduler stats with custom metric (needs dummy transfer to avoid being skipped) + scheduler_stats = NixlKVConnectorStats() + scheduler_stats.data.update( + { # dummy transfer just for testing, to bypass is_empty() check + "transfer_duration": [0], + "post_duration": [0], + "bytes_transferred": [0], + "num_descriptors": [0], + "remote_tokens": [128], + } + ) + + # Mock the scheduler connector's stats method + scheduler.connector.get_kv_connector_stats = lambda: MultiKVConnectorStats( + data={"NixlConnector": scheduler_stats} + ) + + model_output = ModelRunnerOutput( + req_ids=["req_0"], + req_id_to_index={"req_0": 0}, + sampled_token_ids=[[123]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[None], + kv_connector_output=KVConnectorOutput( + kv_connector_stats=MultiKVConnectorStats( + data={"NixlConnector": worker_stats} + ) + ), + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=None, + num_scheduled_tokens={"req_0": 1}, + total_num_scheduled_tokens=1, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_mm_hashes=set(), + structured_output_request_ids={}, + grammar_bitmask=None, + ) + + engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output) + + final_stats = next( + iter(engine_core_outputs.values()) + ).scheduler_stats.kv_connector_stats + nixl_stats = final_stats["NixlConnector"] + assert nixl_stats.num_successful_transfers == 2 + assert nixl_stats.data["remote_tokens"] == [128] + + @pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", @@ -811,12 +932,20 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): "gpu_memory_utilization": 0.5, "kv_transfer_config": kv_transfer_config, "distributed_executor_backend": distributed_executor_backend, + "disable_hybrid_kv_cache_manager": True, } timeout = 6 monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout)) + def run_test_and_cleanup(): + llm = LLM(**llm_kwargs) + try: + _run_abort_timeout_test(llm, timeout) + finally: + llm.llm_engine.engine_core.shutdown() + # Build runtime_env only if we're using Ray if distributed_executor_backend == "ray": with _make_fake_nixl_pkg() as working_dir: @@ -829,15 +958,16 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): }, } ray.init(runtime_env=runtime_env) - - _run_abort_timeout_test(llm_kwargs, timeout) + try: + run_test_and_cleanup() + finally: + ray.shutdown() else: - 
_run_abort_timeout_test(llm_kwargs, timeout) + run_test_and_cleanup() -def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): +def _run_abort_timeout_test(llm: LLM, timeout: int): """Helper function to run the abort timeout test logic.""" - llm = LLM(**llm_kwargs) remote_prefill_opts = { "do_remote_decode": True, "do_remote_prefill": False, @@ -921,7 +1051,7 @@ def test_register_kv_caches(dist_init): ), patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" - ), + ) as mock_thread, ): # noqa: E501 # Create connector connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) @@ -933,6 +1063,9 @@ def test_register_kv_caches(dist_init): mock_wrapper_instance = mock_nixl_wrapper.return_value connector.connector_worker.nixl_wrapper = mock_wrapper_instance + # Reassure the shutdown() check that the thread is terminated + mock_thread.return_value.is_alive.return_value = False + # Execute register_kv_caches connector.register_kv_caches(kv_caches) @@ -982,7 +1115,7 @@ def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: return {"oot": ("oot",)} @classmethod - def get_nixl_memory_type(cls) -> Optional[str]: + def get_nixl_memory_type(cls) -> str | None: """ Returns the nixl memory type for the current platform. """ @@ -1050,6 +1183,7 @@ def test_shutdown_cleans_up_resources(dist_init): with ( patch.object(worker, "_handshake_initiation_executor") as mock_exec, patch.object(worker, "_nixl_handshake_listener_t") as mock_listener, + patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event, patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer, patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist, patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent, @@ -1061,6 +1195,8 @@ def test_shutdown_cleans_up_resources(dist_init): worker._remote_agents = {"engine1": {0: "agent1"}} worker._registered_descs = ["desc1", "desc2"] + mock_listener.is_alive.return_value = False + worker.shutdown() # Test idempotency @@ -1068,7 +1204,8 @@ def test_shutdown_cleans_up_resources(dist_init): worker.shutdown() mock_exec.shutdown.assert_called_with(wait=False) - mock_listener.join.assert_called_once_with(timeout=0) + mock_event.set.assert_called_once() + mock_listener.join.assert_called_once_with(timeout=1.0) mock_rel_xfer.assert_called_once_with(123) assert mock_rel_dlist.call_count == 2 @@ -1144,3 +1281,145 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init): # After abort, the worker should not keep tracking it as "in-batch" assert req.request_id not in connector.connector_worker._reqs_to_process #### Model Runner end #### + + +class FailingNixlWrapper(FakeNixlWrapper): + """Mock NixlWrapper that fails on specific operations.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fail_handshake = False + self.fail_transfer_setup = False + self.fail_send_notif = False + + def add_remote_agent(self, agent_metadata: bytes) -> str: + if self.fail_handshake: + from zmq.error import Again + + raise Again("Simulated timeout failure") + return super().add_remote_agent(agent_metadata) + + def make_prepped_xfer( + self, + xfer_type: str, + local_xfer_side_handle: int, + local_block_descs_ids: list[int], + remote_xfer_side_handle: int, + remote_block_descs_ids: list[int], + notif_msg: bytes | None = None, + ) -> int: + if self.fail_transfer_setup: + # classic RuntimeError to simulate failure + raise RuntimeError("BAD STATUS") + return super().make_prepped_xfer( + xfer_type, 
+ local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg, + ) + + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: + if self.fail_send_notif: + raise RuntimeError("Simulated send_notif failure") + return super().send_notif(agent_name, notif_msg) + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FailingNixlWrapper, +) +def test_handshake_failure_returns_finished(dist_init): + """Test that handshake failures mark blocks invalid and return via get_finished.""" + vllm_config = create_vllm_config() + + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0.1 + ) + connector.connector_worker.nixl_wrapper.fail_handshake = True + + request_id = "test_handshake_fail" + metadata = NixlConnectorMetadata() + metadata.add_new_req( + request_id=request_id, + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }, + ) + connector.bind_connector_metadata(metadata) + + dummy_ctx = ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + ) + connector.start_load_kv(dummy_ctx) + + # Wait for handshake to fail + time.sleep(0.3) + + # Check that blocks were marked invalid + invalid_blocks = connector.get_block_ids_with_load_errors() + assert invalid_blocks == {1, 2, 3} + + # Check that request appears in get_finished + _, done_recving = connector.get_finished(finished_req_ids=set()) + assert request_id in done_recving + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FailingNixlWrapper, +) +def test_transfer_setup_failure_returns_finished(dist_init): + """Test that transfer setup failures mark blocks invalid + and return via get_finished.""" + vllm_config = create_vllm_config() + + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + connector.connector_worker.nixl_wrapper.fail_transfer_setup = True + + request_id = "test_transfer_fail" + metadata = NixlConnectorMetadata() + metadata.add_new_req( + request_id=request_id, + local_block_ids=[7, 8, 9], + kv_transfer_params={ + "remote_block_ids": [10, 11, 12], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }, + ) + connector.bind_connector_metadata(metadata) + + dummy_ctx = ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + ) + connector.start_load_kv(dummy_ctx) + + # Wait for handshake to complete and process ready_requests + connector.bind_connector_metadata(NixlConnectorMetadata()) + time.sleep(0.1) + connector.start_load_kv(dummy_ctx) + + # check that blocks were marked invalid + invalid_blocks = connector.get_block_ids_with_load_errors() + assert invalid_blocks == {7, 8, 9} + + # ensure request appears in get_finished + _, done_recving = connector.get_finished(finished_req_ids=set()) + assert request_id in done_recving diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index 46a5c097094e..23b6c4802d10 100644 --- 
a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -18,7 +18,7 @@ OffloadingConnectorMetadata, ) from vllm.forward_context import ForwardContext -from vllm.utils import sha256 +from vllm.utils.hashing import sha256 from vllm.v1.core.kv_cache_utils import ( BlockHash, get_request_block_hasher, diff --git a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggregator.py similarity index 71% rename from tests/v1/kv_connector/unit/test_output_aggreagator.py rename to tests/v1/kv_connector/unit/test_output_aggregator.py index d05cbe1a2fd4..4dba203ebc7d 100644 --- a/tests/v1/kv_connector/unit/test_output_aggreagator.py +++ b/tests/v1/kv_connector/unit/test_output_aggregator.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Optional import pytest @@ -14,14 +13,16 @@ class DummyModelRunnerOutput(ModelRunnerOutput): def __init__( self, - finished_sending: Optional[set[str]] = None, - finished_recving: Optional[set[str]] = None, - invalid_block_ids: Optional[set[int]] = None, + finished_sending: set[str] | None = None, + finished_recving: set[str] | None = None, + invalid_block_ids: set[int] | None = None, + expected_finished_count: int = 0, ): self.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, invalid_block_ids=invalid_block_ids or set(), + expected_finished_count=expected_finished_count, ) def __repr__(self): @@ -34,7 +35,7 @@ def __repr__(self): def test_aggregate_workers_output(): - aggregator = KVOutputAggregator(world_size=2) + aggregator = KVOutputAggregator(expected_finished_count=2) output1 = DummyModelRunnerOutput() output2 = DummyModelRunnerOutput() @@ -86,7 +87,7 @@ def test_aggregate_workers_output(): def test_async_aggregate_workers_output(): - aggregator = KVOutputAggregator(world_size=2) + aggregator = KVOutputAggregator(expected_finished_count=2) future1: Future[DummyModelRunnerOutput] = Future() future2: Future[DummyModelRunnerOutput] = Future() @@ -159,3 +160,40 @@ def test_async_aggregate_workers_output(): assert aggregated.finished_sending is None assert aggregated.finished_recving == {"req2"} assert aggregated.invalid_block_ids == {3, 4, 5} + + +def test_aggregate_workers_output_with_expected_finished_count(): + # We create the aggregator expecting to collect from 4 workers + aggregator = KVOutputAggregator(expected_finished_count=4) + assert aggregator._expected_finished_count == 4 + # A worker reports a finished request without overriding the expected count + output1 = DummyModelRunnerOutput(finished_sending={"req1"}) + aggregated = aggregator.aggregate([output1]) + # still expecting to collect from 4 workers + assert aggregator._send_remaining_count["req1"] == 3 + assert not aggregated.kv_connector_output.finished_sending + assert not aggregated.kv_connector_output.finished_recving + + # Workers discover that in this setup they only + # need to collect from 2 + output1 = DummyModelRunnerOutput( + finished_sending={"req1"}, expected_finished_count=2 + ) + output2 = DummyModelRunnerOutput( + finished_recving={"req2"}, expected_finished_count=2 + ) + output3 = DummyModelRunnerOutput(finished_recving={"req2"}) + # Req2 only needs 2 acks + aggregated = aggregator.aggregate([output1, output2, output3]) + assert aggregated.kv_connector_output.expected_finished_count == 2 + + assert 
not aggregated.kv_connector_output.finished_sending + + # Req2 is finished + assert "req2" not in aggregator._recv_remaining_count + assert aggregated.kv_connector_output.finished_recving == {"req2"} + + # Req1 is still waiting for 2 more acks (expected_finished_count has no effect) + # NOTE: This is to showcase dynamic update. Workers are responsible for + # ensuring "req1" termination in this case + assert aggregator._send_remaining_count["req1"] == 2 diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index e0404186eb2d..b2ec2ddfb64d 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -43,6 +43,7 @@ def test_basic_lifecycle(): # STEP (1): Prefill. # (1a): schedule() scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 @@ -67,6 +68,7 @@ def test_basic_lifecycle(): assert len(scheduler.waiting) == 0 # ... but blocks should not be freed. + assert len(scheduler.requests) == 1 blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ 0 ].req_to_blocks[request_id] @@ -76,6 +78,7 @@ def test_basic_lifecycle(): # STEP (2): Send Finished to PB. # (2a): schedule() - pass finished request to PB. scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 0 assert len(scheduler_output.finished_req_ids) == 1 assert request_id in scheduler_output.finished_req_ids @@ -92,6 +95,7 @@ def test_basic_lifecycle(): # STEP (3): Finished sending. # (3a): schedule() - pass finished request to PB. scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 0 assert len(scheduler_output.finished_req_ids) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0 @@ -133,6 +137,7 @@ def test_short_prompt_lifecycle(): # STEP (1): Prefill. # (1a): schedule() scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 @@ -178,7 +183,7 @@ def test_prefix_cache_lifecycle(): reqs=[request_normal], use_eos=True ) scheduler.update_from_output(scheduler_output, model_runner_output) - scheduler.schedule() + scheduler_output = scheduler.schedule() scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) ##################### @@ -213,3 +218,45 @@ def test_prefix_cache_lifecycle(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert_scheduler_empty(scheduler) + + +def test_abort_during_kv_transfer(): + """Test aborting request does not release blocks for remote decode.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Prime the KVCache. 
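+ # The prompt spans two full blocks plus a partial block, leaving KV to send.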
+ BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) + + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request]) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler_output = scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + + # Request removed from PB but blocks should not be freed. + assert len(scheduler.requests) == 1 + + # Abort the request, and check the blocks are still not freed + scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED) + assert len(scheduler.requests) == 1 + + # Simulate a finished sending notification + scheduler_output = scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.kv_connector_output = KVConnectorOutput( + finished_sending=[request.request_id] + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_shared_storage_connector.py b/tests/v1/kv_connector/unit/test_shared_storage_connector.py index e7013a794a8c..6040ed5a6806 100644 --- a/tests/v1/kv_connector/unit/test_shared_storage_connector.py +++ b/tests/v1/kv_connector/unit/test_shared_storage_connector.py @@ -132,6 +132,7 @@ def test_shared_storage_connector_hashes(tmp_path): enforce_eager=True, kv_transfer_config=kv_transfer_config, limit_mm_per_prompt={"image": 2}, + disable_hybrid_kv_cache_manager=True, ) # don't put this import at the top level diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 24c0bd51216d..46ea46e53084 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import tempfile from collections import defaultdict +from collections.abc import Callable from itertools import count -from typing import Any, Callable, Optional +from typing import Any import torch @@ -20,7 +21,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa SharedStorageConnector, ) -from vllm.utils import sha256 +from vllm.utils.hashing import sha256 from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.scheduler import Scheduler @@ -82,6 +83,7 @@ def create_vllm_config( block_size: int = 16, max_model_len: int = 10000, enable_chunked_prefill: bool = True, + enable_permute_local_kv: bool = False, ) -> VllmConfig: """Initialize VllmConfig For Testing.""" scheduler_config = SchedulerConfig( @@ -89,6 +91,9 @@ def create_vllm_config( max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len, enable_chunked_prefill=enable_chunked_prefill, + # Disable hybrid KV cache manager for testing + # Should be removed after we support hybrid KV cache manager-based testing. 
+ disable_hybrid_kv_cache_manager=True, ) model_config = ModelConfig( model=model, @@ -107,6 +112,7 @@ def create_vllm_config( kv_transfer_config = KVTransferConfig( kv_connector="NixlConnector", kv_role="kv_both", + enable_permute_local_kv=enable_permute_local_kv, ) return VllmConfig( scheduler_config=scheduler_config, @@ -138,6 +144,7 @@ def create_scheduler( kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), + block_size=block_size, ) @@ -146,7 +153,7 @@ def create_scheduler( def create_request( - request_id: Optional[int] = None, + request_id: int | None = None, num_tokens: int = 10, common_prefix_len=0, max_tokens: int = 16, @@ -167,7 +174,7 @@ def create_request( init_none_hash(hash_fn) _none_hash_initialized = True - kv_transfer_params: Optional[dict[str, Any]] = None + kv_transfer_params: dict[str, Any] | None = None if do_remote_decode: assert not do_remote_prefill @@ -204,9 +211,9 @@ def create_request( def create_model_runner_output( reqs: list[Request], - finished_sending: Optional[set[str]] = None, - finished_recving: Optional[set[str]] = None, - invalid_block_ids: Optional[set[int]] = None, + finished_sending: set[str] | None = None, + finished_recving: set[str] | None = None, + invalid_block_ids: set[int] | None = None, use_eos: bool = False, token_id: int = 0, ) -> ModelRunnerOutput: diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py index 81b57f1ca0c8..0d4fa344d298 100644 --- a/tests/v1/kv_offload/test_cpu_gpu.py +++ b/tests/v1/kv_offload/test_cpu_gpu.py @@ -8,11 +8,20 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend -from vllm.v1.attention.backends.flashinfer import FlashInferBackend -from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler +BACKENDS_TO_TEST = [FlashAttentionBackend] + +if not current_platform.is_rocm(): + from vllm.v1.attention.backends.flashinfer import FlashInferBackend + + BACKENDS_TO_TEST.append(FlashInferBackend) + + from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend + + BACKENDS_TO_TEST.append(FlashAttnMLABackend) + NUM_GPU_BLOCKS = [64] NUM_CPU_BLOCKS = [256] GPU_BLOCK_SIZES = [16] @@ -55,8 +64,8 @@ def test_transfer( ) -> None: current_platform.seed_everything(seed) - # create per-layer GPU KV caches - attn_backends_list = [FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend] + # create per-layer GPU KV caches based on available attn_backends + attn_backends_list = BACKENDS_TO_TEST gpu_caches = {} attn_backends = {} diff --git a/tests/v1/kv_offload/test_cpu_manager.py b/tests/v1/kv_offload/test_cpu_manager.py index 57884f846b51..4f90ca022cef 100644 --- a/tests/v1/kv_offload/test_cpu_manager.py +++ b/tests/v1/kv_offload/test_cpu_manager.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional import numpy as np @@ -29,7 +28,7 @@ def to_hashes(int_hashes: list[int]) -> list[BlockHash]: def verify_store_output( - prepare_store_output: Optional[PrepareStoreOutput], + prepare_store_output: PrepareStoreOutput | None, expected_prepare_store_output: ExpectedPrepareStoreOutput, ): assert prepare_store_output is not None diff --git 
a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py index 0d90cc715fd4..e9c255b1ee99 100644 --- a/tests/v1/kv_offload/test_cpu_offloading.py +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -27,6 +27,7 @@ def test_cpu_offloading(cpu_block_size: int) -> None: model="meta-llama/Llama-3.2-1B-Instruct", gpu_memory_utilization=0.5, kv_transfer_config=kv_transfer_config, + disable_hybrid_kv_cache_manager=True, ) prompts = ["Hi " * 100] diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index 538b6281f5a0..dac7ffed69d4 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -3,7 +3,7 @@ import random from collections.abc import Callable -from typing import NamedTuple, Optional, Union +from typing import NamedTuple, TypeAlias import numpy as np import pytest @@ -21,7 +21,7 @@ from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.sample.logits_processor import ( BatchUpdate, BatchUpdateBuilder, @@ -48,7 +48,7 @@ STR_NO_LOGITPROC = "none" # LogitsProcessor subclass or "none" -LogitprocType = Union[type[LogitsProcessor], str] +LogitprocType: TypeAlias = type[LogitsProcessor] | str class LogitsProcsRequestParams: @@ -435,7 +435,7 @@ class LogitsprocTestHelpers(NamedTuple): """Supports setting up and validating logitsprocs unit tests.""" eval_fxn: Callable - gen_request_fxn: Optional[Callable] = None + gen_request_fxn: Callable | None = None logitsprocs_test_mapping = { @@ -471,7 +471,7 @@ def _generate_fake_step_update( workload_params: list[LogitsProcsRequestParams], wdx: int, batch_update_builder: BatchUpdateBuilder, -) -> tuple[Optional[BatchUpdate], int, int]: +) -> tuple[BatchUpdate | None, int, int]: batch_size = len(persistent_batch) workload_size = len(workload_params) workload_reqs_remaining = workload_size - wdx diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py index 95ddb1849169..1899737737f4 100644 --- a/tests/v1/logits_processors/test_custom_offline.py +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random import sys -from typing import Any, Union +from typing import Any import pytest @@ -159,7 +159,7 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource _run_test({}, logitproc_loaded=True) return - kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} + kwargs: dict[str, list[str | type[LogitsProcessor]]] = {} if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: # Scenario: load logitproc based on fully-qualified class name (FQCN) # Inject dummy module which defines logitproc diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py index 9c5b4ff0ba17..0d902b46bed5 100644 --- a/tests/v1/logits_processors/test_custom_online.py +++ b/tests/v1/logits_processors/test_custom_online.py @@ -4,7 +4,7 @@ import os import random import sys -from typing import Any, Optional +from typing import Any import openai import pytest @@ -25,7 +25,7 @@ def _server_with_logitproc_entrypoint( - env_dict: Optional[dict[str, str]], + env_dict: dict[str, str] | None, model: str, 
vllm_serve_args: list[str], ) -> None: @@ -48,7 +48,7 @@ def _server_with_logitproc_entrypoint( def _server_with_logitproc_module( - env_dict: Optional[dict[str, str]], + env_dict: dict[str, str] | None, model: str, vllm_serve_args: list[str], ) -> None: diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index 9a1d5505a5f9..36cffebb3b45 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -3,7 +3,7 @@ import types from enum import Enum, auto -from typing import Any, Optional +from typing import Any import torch @@ -61,7 +61,7 @@ def is_argmax_invariant(self) -> bool: """Never impacts greedy sampling""" return False - def update_state(self, batch_update: Optional[BatchUpdate]): + def update_state(self, batch_update: BatchUpdate | None): process_dict_updates( self.req_info, batch_update, @@ -145,7 +145,7 @@ def is_argmax_invariant(self) -> bool: def new_req_logits_processor( self, params: SamplingParams, - ) -> Optional[RequestLogitsProcessor]: + ) -> RequestLogitsProcessor | None: """This method returns a new request-level logits processor, customized to the `target_token` value associated with a particular request. @@ -159,7 +159,7 @@ def new_req_logits_processor( Returns: `Callable` request logits processor, or None """ - target_token: Optional[Any] = params.extra_args and params.extra_args.get( + target_token: Any | None = params.extra_args and params.extra_args.get( "target_token" ) if target_token is None: diff --git a/tests/v1/metrics/test_engine_logger_apis.py b/tests/v1/metrics/test_engine_logger_apis.py index bf780b1f36ad..2e243c23cbf9 100644 --- a/tests/v1/metrics/test_engine_logger_apis.py +++ b/tests/v1/metrics/test_engine_logger_apis.py @@ -4,33 +4,13 @@ import pytest +from tests.plugins.vllm_add_dummy_stat_logger.dummy_stat_logger.dummy_stat_logger import ( # noqa E501 + DummyStatLogger, +) from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger -class DummyStatLogger: - """ - A dummy stat logger for testing purposes. - Implements the minimal interface expected by StatLoggerManager. 
- """ - - def __init__(self, vllm_config, engine_idx): - self.vllm_config = vllm_config - self.engine_idx = engine_idx - self.recorded = [] - self.logged = False - self.engine_initialized = False - - def record(self, scheduler_stats, iteration_stats, engine_idx): - self.recorded.append((scheduler_stats, iteration_stats, engine_idx)) - - def log(self): - self.logged = True - - def log_engine_initialized(self): - self.engine_initialized = True - - @pytest.fixture def log_stats_enabled_engine_args(): """ @@ -54,7 +34,7 @@ async def test_async_llm_replace_default_loggers(log_stats_enabled_engine_args): engine = AsyncLLM.from_engine_args( log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger] ) - assert isinstance(engine.logger_manager.prometheus_logger, RayPrometheusStatLogger) + assert isinstance(engine.logger_manager.stat_loggers[0], RayPrometheusStatLogger) engine.shutdown() @@ -73,9 +53,11 @@ async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args): disabled_log_engine_args, stat_loggers=[DummyStatLogger] ) - assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1 + assert len(engine.logger_manager.stat_loggers) == 2 + assert len(engine.logger_manager.stat_loggers[0].per_engine_stat_loggers) == 1 assert isinstance( - engine.logger_manager.per_engine_logger_dict[0][0], DummyStatLogger + engine.logger_manager.stat_loggers[0].per_engine_stat_loggers[0], + DummyStatLogger, ) # log_stats is still True, since custom stat loggers are used diff --git a/tests/v1/metrics/test_ray_metrics.py b/tests/v1/metrics/test_ray_metrics.py index 2cb5e6733b79..f08d9f684921 100644 --- a/tests/v1/metrics/test_ray_metrics.py +++ b/tests/v1/metrics/test_ray_metrics.py @@ -4,7 +4,7 @@ import pytest import ray -from vllm.config import ModelDType +from vllm.config.model import ModelDType from vllm.sampling_params import SamplingParams from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM from vllm.v1.metrics.ray_wrappers import RayPrometheusMetric, RayPrometheusStatLogger diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index bda430a080f6..6d4a1ecf78c8 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -2,12 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +import math from collections.abc import Generator from typing import get_args import pytest import torch +from tests.utils import large_gpu_mark from tests.v1.sample.utils import ( BatchLogprobsComposition, BatchLogprobsSpecType, @@ -16,7 +18,8 @@ get_test_batch, ) from vllm import SamplingParams -from vllm.config import LogprobsMode +from vllm.config.model import LogprobsMode +from vllm.distributed import cleanup_dist_env_and_memory from ...conftest import HfRunner, VllmRunner @@ -459,7 +462,7 @@ def test_all_logprobs(example_prompts): results_logprobs_all = runner.llm.generate( example_prompts, sampling_params=sampling_params_logprobs_all ) - vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size() + vocab_size = runner.llm.llm_engine.model_config.get_vocab_size() for i in range(len(results_logprobs_all)): logprobs = results_logprobs_all[i].outputs[0].logprobs @@ -508,3 +511,94 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): if logprobs_mode in ("raw_logits", "processed_logits"): assert positive_values > 0 del llm + + +@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode)) +@pytest.mark.parametrize( + "model_setup", + [ + pytest.param( + ( + "eagle", + 
"meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", + ), + marks=large_gpu_mark(min_gb=32), + ), + ], +) +def test_spec_decode_logprobs( + logprobs_mode: LogprobsMode, + model_setup: tuple[str, str, str], + monkeypatch: pytest.MonkeyPatch, +): + """Spec decode logprobs should match those of the base model. + + Args: + logprobs_mode: logprobs mode. + model_setup: Spec decode method, base model name, and + draft model name. + """ + from vllm import LLM + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + prompt = "Hello world" + sampling_params = SamplingParams( + temperature=0, logprobs=3, max_tokens=10, ignore_eos=False + ) + method, model_name, spec_model_name = model_setup + max_model_len = 256 + + # Run base LLM. + ref_llm = LLM( + model=model_name, + max_logprobs=5, + max_model_len=max_model_len, + seed=42, + logprobs_mode=logprobs_mode, + gpu_memory_utilization=0.4, + ) + ref_results = ref_llm.generate([prompt], sampling_params) + # Collect logprobs outputs from reference LLM. + ref_logprobs = [] + for output in ref_results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + ref_logprobs.append(logprobs[token_id]) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + # Run spec decode LLM. + spec_llm = LLM( + model_name, + speculative_config={ + "method": method, + "model": spec_model_name, + "num_speculative_tokens": 3, + "max_model_len": max_model_len, + }, + max_logprobs=5, + max_model_len=max_model_len, + seed=42, + logprobs_mode=logprobs_mode, + gpu_memory_utilization=0.4, + ) + spec_results = spec_llm.generate([prompt], sampling_params) + # Collect logprobs outputs from spec decode LLM. + spec_logprobs = [] + for output in spec_results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + spec_logprobs.append(logprobs[token_id]) + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + # Per-token logprobs are expected to be the same. 
+ assert len(ref_logprobs) == len(spec_logprobs) + for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): + assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3) + assert ref_logprob.rank == spec_logprob.rank + assert ref_logprob.decoded_token == spec_logprob.decoded_token diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 8df10f8c3afa..bf7726ebf907 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any +from unittest.mock import Mock import pytest import torch @@ -11,6 +12,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler +from vllm.v1.sample.sampler import Sampler, SamplerOutput from vllm.v1.spec_decode.metadata import SpecDecodeMetadata DEVICE = current_platform.device_type @@ -18,13 +20,34 @@ @pytest.fixture def rejection_sampler(): - return RejectionSampler() + mock_sampler = Mock(spec=Sampler) + mock_sampler.logprobs_mode = "raw_logprobs" + return RejectionSampler(mock_sampler) + + +def mock_sampler_output( + rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor +): + rejection_sampler.sampler.return_value = SamplerOutput( + sampled_token_ids=bonus_token_ids, logprobs_tensors=None + ) + + +def create_spec_decode_metadata( + spec_tokens: list[list[int]], logits: torch.Tensor +) -> SpecDecodeMetadata: + metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device) + metadata.target_logits_indices = torch.arange(logits.shape[0]) + # Output bonus token ids are mocked, so the bonus logit indices should + # be empty. 
+ metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32) + return metadata def create_logits_tensor( output_token_ids: list[list[int]], vocab_size: int = 100, - token_idx_to_override: Optional[int] = None, + token_idx_to_override: int | None = None, ) -> torch.Tensor: """Helper function to create logits tensor that will produce desired token ids on argmax""" @@ -43,18 +66,18 @@ def create_logits_tensor( def create_sampling_metadata( all_greedy: bool, - output_token_ids: Optional[list[list[int]]] = None, - prompt_token_ids: Optional[torch.Tensor] = None, - spec_token_ids: Optional[torch.Tensor] = None, - temperature: Optional[torch.Tensor] = None, - top_k: Optional[torch.Tensor] = None, - top_p: Optional[torch.Tensor] = None, - generators: Optional[dict[int, Any]] = None, - frequency_penalties: Optional[list[float]] = None, - presence_penalties: Optional[list[float]] = None, - repetition_penalties: Optional[list[float]] = None, - bad_words_token_ids: Optional[dict[int, list[list[int]]]] = None, - allowed_token_ids_mask: Optional[torch.Tensor] = None, + output_token_ids: list[list[int]] | None = None, + prompt_token_ids: torch.Tensor | None = None, + spec_token_ids: torch.Tensor | None = None, + temperature: torch.Tensor | None = None, + top_k: torch.Tensor | None = None, + top_p: torch.Tensor | None = None, + generators: dict[int, Any] | None = None, + frequency_penalties: list[float] | None = None, + presence_penalties: list[float] | None = None, + repetition_penalties: list[float] | None = None, + bad_words_token_ids: dict[int, list[list[int]]] | None = None, + allowed_token_ids_mask: torch.Tensor | None = None, ) -> SamplingMetadata: """Create a v1 sampling metadata object with all_greedy set to the given value. Either all greedy or all random sampling @@ -111,19 +134,17 @@ def test_perfect_match(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_early_mismatch(rejection_sampler): @@ -134,15 +155,13 @@ def test_early_mismatch(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( @@ -150,7 +169,7 @@ def test_early_mismatch(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def 
test_multiple_sequences(rejection_sampler): @@ -163,21 +182,19 @@ def test_multiple_sequences(rejection_sampler): bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_single_token_sequence(rejection_sampler): @@ -188,19 +205,17 @@ def test_single_token_sequence(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_empty_sequence(rejection_sampler): @@ -211,19 +226,17 @@ def test_empty_sequence(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_multiple_mismatches(rejection_sampler): @@ -236,15 +249,13 @@ def test_multiple_mismatches(rejection_sampler): bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( @@ -255,7 +266,7 @@ def test_multiple_mismatches(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) @pytest.mark.parametrize( @@ -277,19 +288,17 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, 
output_tokens, expec bonus_token_tensor = torch.tensor( [tokens[-1] for tokens in output_tokens], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) - assert torch.equal(output, expected_tensor) + assert torch.equal(output.sampled_token_ids, expected_tensor) ########################### Tests for Random Sampling ################### @@ -331,18 +340,19 @@ def test_deterministic_when_seeded( sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature, generators=seeded_seqs ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=DEVICE + spec_decode_metadata = create_spec_decode_metadata( + draft_token_ids.tolist(), target_logits ) + + mock_sampler_output(rejection_sampler, bonus_token_ids) rep_result = rejection_sampler( spec_decode_metadata, - draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + draft_probs=None, + logits=target_logits, sampling_metadata=sampling_metadata, ) - results.append(rep_result) + results.append(rep_result.sampled_token_ids) for i in range(batch_size): if seeded_mask[i]: @@ -460,7 +470,9 @@ def estimate_rejection_sampling_pdf( Returns: Estimated probability distribution of the output tokens. """ - rejection_sampler = RejectionSampler() + mock_sampler = Mock(spec=Sampler) + mock_sampler.logprobs_mode = "raw_logprobs" + rejection_sampler = RejectionSampler(mock_sampler) num_tokens = num_samples * k # Repeat draft probs num_samples * k times. 
draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1) @@ -483,17 +495,18 @@ def estimate_rejection_sampling_pdf( sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=bonus_token_ids.device + spec_decode_metadata = create_spec_decode_metadata( + draft_token_ids.tolist(), target_logits ) - output_token_ids = rejection_sampler( + + mock_sampler_output(rejection_sampler, bonus_token_ids) + sampler_output = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + logits=target_logits, sampling_metadata=sampling_metadata, ) - output_token_ids = output_token_ids[:, :-1].flatten() + output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten() hist = torch.histogram( output_token_ids.to(dtype=torch.float, device="cpu"), @@ -532,22 +545,19 @@ def _test_masked_logits( bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE) # Create spec decode metadata - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, - device=DEVICE, - ) + spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits) # Run rejection sampling - output_token_ids = rejection_sampler( + mock_sampler_output(rejection_sampler, bonus_token_ids) + output = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, - target_logits=target_logits, - bonus_token_ids=bonus_token_ids, + logits=target_logits, sampling_metadata=sampling_metadata, ) # Remove bonus tokens and reshape - output_token_ids = output_token_ids[:, :-1].flatten().tolist() + output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist() # Check that all sampled tokens are within the unmasked indices. 
for i in range(num_tokens): @@ -665,11 +675,11 @@ def test_frequency_penalties(rejection_sampler): spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_tokens, device=logits.device ) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) expected = torch.tensor( @@ -677,7 +687,7 @@ def test_frequency_penalties(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_bad_words(rejection_sampler): @@ -707,14 +717,12 @@ def test_bad_words(rejection_sampler): bonus_token_tensor = torch.tensor( [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) @@ -723,7 +731,7 @@ def test_bad_words(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) def test_allowed_token_ids(rejection_sampler): @@ -756,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler): bonus_token_tensor = torch.tensor( [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device ) - spec_decode_metadata = SpecDecodeMetadata.make_dummy( - spec_tokens, device=logits.device - ) + spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits) + mock_sampler_output(rejection_sampler, bonus_token_tensor) output = rejection_sampler( spec_decode_metadata, draft_probs=None, - target_logits=logits, - bonus_token_ids=bonus_token_tensor, + logits=logits, sampling_metadata=metadata, ) @@ -772,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler): dtype=torch.int, device=logits.device, ) - assert torch.equal(output, expected) + assert torch.equal(output.sampled_token_ids, expected) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index edc6acae848a..51f2bf5e753c 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -7,7 +7,8 @@ from tests.v1.sample.utils import create_allowed_token_ids from vllm.platforms import current_platform -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index bdde28fe0342..915b9957031d 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -5,16 +5,13 @@ from vllm import LLM, SamplingParams -MODEL = "meta-llama/Llama-3.2-1B" +MODEL = "hmellor/tiny-random-LlamaForCausalLM" PROMPT = "Hello my name is Robert and I" @pytest.fixture(scope="module") def llm() -> LLM: - # Disable prefix caching so that we can test prompt logprobs. 
- # TODO remove this after https://github.com/vllm-project/vllm/pull/13949 - # is merged - return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False) + return LLM(MODEL, enforce_eager=True) def test_n_gt_1(llm): diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index c70cbebe22ca..f50ef6102204 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -5,20 +5,13 @@ from torch import Generator from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import ( - apply_top_k_top_p, - is_flashinfer_available, -) +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p DEVICE = current_platform.device_type BATCH_SIZE = 1024 VOCAB_SIZE = 128 * 1024 -FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available -if is_flashinfer_available: - from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs - @pytest.fixture(autouse=True) def reset_default_device(): @@ -65,6 +58,14 @@ def test_flashinfer_sampler(): sampling results due to randomness), so we will compare the probability renormed consequently by top-k and then top-p of FlashInfer implementation. """ + try: + from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs + + is_flashinfer_available = True + except ImportError: + is_flashinfer_available = False + + FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available if not FLASHINFER_ENABLED: pytest.skip("FlashInfer not installed or not available on this platform.") diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index b1c63327b852..a0abb3b4c6ce 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -3,13 +3,13 @@ from collections.abc import Iterator from enum import Enum -from typing import NamedTuple, Optional +from typing import NamedTuple import regex as re import torch from vllm import CompletionOutput -from vllm.utils import make_tensor_with_pad +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata @@ -23,7 +23,7 @@ class BatchLogprobsComposition(Enum): SAMPLE_PROMPT = 3 -BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]] +BatchLogprobsSpecType = list[tuple[int | None, int | None]] def get_test_batch( @@ -222,8 +222,8 @@ def create_allowed_token_ids( vocab_size: int, num_allowed_token_ids: int, device: torch.device, -) -> Optional[torch.Tensor]: - mask: Optional[torch.Tensor] = None +) -> torch.Tensor | None: + mask: torch.Tensor | None = None for i in range(batch_size): if i % 2 == 1: continue diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py index d94357827864..ee04dfad3906 100644 --- a/tests/v1/shutdown/test_delete.py +++ b/tests/v1/shutdown/test_delete.py @@ -12,10 +12,10 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import RequestOutputKind -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM -MODELS = ["meta-llama/Llama-3.2-1B"] +MODELS = ["hmellor/tiny-random-LlamaForCausalLM"] @pytest.mark.asyncio diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py index 383348e88540..a751b2d919e1 100644 --- a/tests/v1/shutdown/test_forward_error.py +++ 
b/tests/v1/shutdown/test_forward_error.py @@ -14,11 +14,11 @@ from vllm import LLM, AsyncEngineArgs, SamplingParams from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.exceptions import EngineDeadError -MODELS = ["meta-llama/Llama-3.2-1B"] +MODELS = ["hmellor/tiny-random-LlamaForCausalLM"] def evil_forward(self, *args, **kwargs): diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py index 019c0c4d7cf0..c1594cc2e8b7 100644 --- a/tests/v1/shutdown/test_startup_error.py +++ b/tests/v1/shutdown/test_startup_error.py @@ -13,10 +13,10 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.engine.arg_utils import AsyncEngineArgs from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM -MODELS = ["meta-llama/Llama-3.2-1B"] +MODELS = ["hmellor/tiny-random-LlamaForCausalLM"] def evil_method(self, *args, **kwargs): @@ -76,8 +76,10 @@ def test_llm_startup_error( Test profiling (forward()) and load weights failures. TODO(andy) - LLM without multiprocessing. """ - if model != "meta-llama/Llama-3.2-1B": - pytest.skip(reason="Only test meta-llama/Llama-3.2-1B") + # Skip non-Llama models since we monkeypatch LlamaForCausalLM specifically. + # If MODELS list grows, each architecture needs its own test variant. + if model != "JackFram/llama-68m": + pytest.skip(reason="Only test JackFram/llama-68m") if cuda_device_count_stateless() < tensor_parallel_size: pytest.skip(reason="Not enough CUDA devices") diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 4c490f2188aa..47d05a20a65d 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from unittest import mock import pytest @@ -12,7 +11,7 @@ BatchSpec, create_common_attn_metadata, create_standard_kv_cache_spec, - get_attention_backend, + try_get_attention_backend, ) from vllm.attention.backends.registry import _Backend from vllm.config import ( @@ -39,7 +38,7 @@ def _create_proposer( method: str, num_speculative_tokens: int, - speculative_token_tree: Optional[list[tuple[int, ...]]] = None, + speculative_token_tree: list[tuple[int, ...]] | None = None, ) -> EagleProposer: model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) @@ -535,11 +534,11 @@ def create_deterministic_logits(token_ids): sampling_metadata = mock.MagicMock() if attn_backend == "FLASH_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) elif attn_backend == "TRITON_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TRITON_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN) elif attn_backend == "TREE_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) else: raise ValueError(f"Unsupported attention backend: 
{attn_backend}") @@ -674,7 +673,7 @@ def create_deterministic_logits(token_ids, k: int): proposer.attn_layer_names = ["layer.0"] # Get the tree attention metadata builder. - attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), layer_names=proposer.attn_layer_names, diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index d7d9ef07e46c..9ca7cf9e3e0e 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -10,7 +10,7 @@ BatchSpec, create_common_attn_metadata, create_standard_kv_cache_spec, - get_attention_backend, + try_get_attention_backend, ) from vllm.attention.backends.registry import _Backend from vllm.config import ( @@ -177,7 +177,7 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): sampling_metadata = mock.MagicMock() # Setup attention metadata - attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN) attn_metadata_builder = attn_metadata_builder_cls( kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index a46e8e3ec755..b365e75d5514 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -2,14 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import Optional import torch from tests.v1.attention.utils import ( create_standard_kv_cache_spec, create_vllm_config, - get_attention_backend, + try_get_attention_backend, ) from vllm.attention.backends.registry import _Backend from vllm.config import ParallelConfig, SpeculativeConfig @@ -37,7 +36,7 @@ def forward_attention( slot_mapping: torch.Tensor, seqlen_k: int, backend: _Backend, - spec_token_tree: Optional[str] = None, + spec_token_tree: str | None = None, num_spec_tokens: int = 0, ) -> torch.Tensor: batch_size, q_len, num_heads, dim_per_head = q.shape @@ -63,7 +62,7 @@ def forward_attention( # Build common metadata. model_name = "meta-llama/Meta-Llama-3-8B" - builder_cls, impl_cls = get_attention_backend(backend) + builder_cls, impl_cls = try_get_attention_backend(backend) vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens)) if spec_token_tree is not None: # Create speculative config if token tree is specified. 
diff --git a/tests/v1/structured_output/test_gptoss_structural_tags.py b/tests/v1/structured_output/test_gptoss_structural_tags.py new file mode 100644 index 000000000000..f0feabfb99ab --- /dev/null +++ b/tests/v1/structured_output/test_gptoss_structural_tags.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for GPT-OSS structural tag support in reasoning (PR #25515).""" + +import json +from unittest.mock import Mock + +import pytest + +from vllm.entrypoints.tool_server import ToolServer +from vllm.reasoning.gptoss_reasoning_parser import ( + GptOssReasoningParser, + from_builtin_tool_to_tag, + no_func_reaonsing_tag, + tag_with_builtin_funcs, +) + + +class TestGptOssReasoningParser: + """Test cases for GptOssReasoningParser structural tag functionality.""" + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock tokenizer for testing.""" + tokenizer = Mock() + tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) + return tokenizer + + @pytest.fixture + def reasoning_parser(self, mock_tokenizer): + """Create a GptOssReasoningParser instance.""" + return GptOssReasoningParser(mock_tokenizer) + + @pytest.fixture + def mock_tool_server_empty(self): + """Create a mock ToolServer with no tools.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(return_value=False) + return tool_server + + @pytest.fixture + def mock_tool_server_with_browser(self): + """Create a mock ToolServer with browser tool.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(side_effect=lambda tool: tool == "browser") + return tool_server + + @pytest.fixture + def mock_tool_server_with_all_tools(self): + """Create a mock ToolServer with all builtin tools.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock( + side_effect=lambda tool: tool in ["browser", "python", "container"] + ) + return tool_server + + def test_prepare_structured_tag_no_tool_server(self, reasoning_parser): + """Test prepare_structured_tag with no tool server.""" + result = reasoning_parser.prepare_structured_tag(None, None) + expected = json.dumps(no_func_reaonsing_tag) + + assert result == expected + + # Verify the structure is correct + parsed = json.loads(result) + assert parsed["type"] == "structural_tag" + assert parsed["format"]["type"] == "triggered_tags" + assert len(parsed["format"]["tags"]) == 1 + assert parsed["format"]["tags"][0]["begin"] == "<|channel|>analysis<|message|>" + assert parsed["format"]["triggers"] == ["<|channel|>analysis"] + + def test_prepare_structured_tag_with_all_tools( + self, reasoning_parser, mock_tool_server_with_all_tools + ): + """Test prepare_structured_tag with all builtin tools.""" + result = reasoning_parser.prepare_structured_tag( + None, mock_tool_server_with_all_tools + ) + parsed = json.loads(result) + + # Should have analysis tag + tags for all 3 tools (2 tags each) + assert len(parsed["format"]["tags"]) == 7 # 1 analysis + 6 tool tags + + # Check all tool tags are present + tag_begins = [tag["begin"] for tag in parsed["format"]["tags"]] + for tool in ["browser", "python", "container"]: + assert f"<|channel|>commentary to={tool}" in tag_begins + assert f"<|channel|>analysis to={tool}" in tag_begins + + def test_prepare_structured_tag_with_original_tag(self, reasoning_parser): + """Test prepare_structured_tag when original_tag is provided.""" + original_tag = '{"custom": "tag"}' + result = reasoning_parser.prepare_structured_tag(original_tag, None) 
+ + # Should return the original tag unchanged + assert result == original_tag + + def test_from_builtin_tool_to_tag(self): + """Test from_builtin_tool_to_tag function.""" + tags = from_builtin_tool_to_tag("python") + + assert len(tags) == 2 + assert tags[0]["begin"] == "<|channel|>commentary to=python" + assert tags[0]["content"]["type"] == "any_text" + assert tags[0]["end"] == "<|end|>" + + assert tags[1]["begin"] == "<|channel|>analysis to=python" + assert tags[1]["content"]["type"] == "any_text" + assert tags[1]["end"] == "<|end|>" + + def test_tag_with_builtin_funcs(self): + """Test tag_with_builtin_funcs function.""" + builtin_tools = ["browser", "python"] + result = tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tools) + + assert result["type"] == "structural_tag" + # Should have original analysis tag + 2 tags per tool + assert len(result["format"]["tags"]) == 5 # 1 + 2*2 + + # Should have added commentary trigger + assert "<|channel|>commentary to=" in result["format"]["triggers"] + assert "<|channel|>analysis" in result["format"]["triggers"] + + def test_tag_structure_invariants(self): + """Test that the basic tag structure follows expected format.""" + # Test the base no_func_reaonsing_tag structure + assert no_func_reaonsing_tag["type"] == "structural_tag" + assert no_func_reaonsing_tag["format"]["type"] == "triggered_tags" + assert no_func_reaonsing_tag["format"]["stop_after_first"] is False + + # Verify analysis tag structure + analysis_tag = no_func_reaonsing_tag["format"]["tags"][0] + assert analysis_tag["begin"] == "<|channel|>analysis<|message|>" + assert analysis_tag["content"]["type"] == "any_text" + assert analysis_tag["end"] == "<|end|>" + + def test_json_serialization_valid( + self, reasoning_parser, mock_tool_server_with_all_tools + ): + """Test that all generated tags produce valid JSON.""" + # Test with no tool server + result1 = reasoning_parser.prepare_structured_tag(None, None) + json.loads(result1) # Should not raise + + # Test with empty tool server + empty_server = Mock(spec=ToolServer) + empty_server.has_tool = Mock(return_value=False) + result2 = reasoning_parser.prepare_structured_tag(None, empty_server) + json.loads(result2) # Should not raise + + # Test with tools + result3 = reasoning_parser.prepare_structured_tag( + None, mock_tool_server_with_all_tools + ) + json.loads(result3) # Should not raise + + @pytest.mark.parametrize("tool_name", ["browser", "python", "container"]) + def test_single_tool_integration(self, reasoning_parser, tool_name): + """Test integration with individual tools.""" + tool_server = Mock(spec=ToolServer) + tool_server.has_tool = Mock(side_effect=lambda tool: tool == tool_name) + + result = reasoning_parser.prepare_structured_tag(None, tool_server) + parsed = json.loads(result) + + # Should have 1 analysis + 2 tool-specific tags + assert len(parsed["format"]["tags"]) == 3 + + tag_begins = [tag["begin"] for tag in parsed["format"]["tags"]] + assert f"<|channel|>commentary to={tool_name}" in tag_begins + assert f"<|channel|>analysis to={tool_name}" in tag_begins diff --git a/tests/v1/structured_output/test_reasoning_structured_output.py b/tests/v1/structured_output/test_reasoning_structured_output.py new file mode 100644 index 000000000000..70047a993c3f --- /dev/null +++ b/tests/v1/structured_output/test_reasoning_structured_output.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for reasoning-aware structured output 
functionality (PR #25515).""" + +from unittest.mock import Mock + +import pytest + +from vllm.config import ModelConfig, SchedulerConfig, VllmConfig +from vllm.reasoning import ReasoningParser +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + + +class TestReasoningStructuredOutput: + """Test reasoning-aware structured output functionality.""" + + @pytest.fixture + def mock_model_config(self): + """Create a mock ModelConfig.""" + config = Mock(spec=ModelConfig) + config.skip_tokenizer_init = True # Skip tokenizer init to avoid network calls + config.get_vocab_size = Mock(return_value=50000) + # Add missing runner_type attribute that tokenizer initialization expects + config.runner_type = "generate" + # Add other attributes that tokenizer initialization might need + config.tokenizer = "test-tokenizer" + config.tokenizer_mode = "auto" + config.trust_remote_code = False + config.tokenizer_revision = None + return config + + @pytest.fixture + def mock_scheduler_config(self): + """Create a mock SchedulerConfig.""" + config = Mock(spec=SchedulerConfig) + config.max_num_seqs = 128 + return config + + @pytest.fixture + def mock_vllm_config(self, mock_model_config, mock_scheduler_config): + """Create a mock VllmConfig.""" + config = Mock(spec=VllmConfig) + config.model_config = mock_model_config + config.scheduler_config = mock_scheduler_config + config.structured_outputs_config = Mock() + config.structured_outputs_config.reasoning_parser = None + config.structured_outputs_config.enable_in_reasoning = False + config.speculative_config = None + return config + + @pytest.fixture + def mock_reasoning_parser(self): + """Create a mock ReasoningParser.""" + parser = Mock(spec=ReasoningParser) + parser.is_reasoning_end = Mock(return_value=False) + return parser + + @pytest.fixture + def mock_request_with_structured_output(self): + """Create a mock request with structured output.""" + request = Mock(spec=Request) + request.structured_output_request = Mock() + request.structured_output_request.reasoning_ended = None + request.structured_output_request.grammar = Mock() + request.structured_output_request.grammar.is_terminated = Mock( + return_value=False + ) + request.use_structured_output = True + request.prompt_token_ids = [1, 2, 3, 4, 5] + request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8] + return request + + def test_should_fill_bitmask_with_enable_in_reasoning( + self, mock_vllm_config, mock_request_with_structured_output + ): + """Test should_fill_bitmask when enable_in_reasoning is True.""" + # Enable enable_in_reasoning + mock_vllm_config.structured_outputs_config.enable_in_reasoning = True + + manager = StructuredOutputManager(mock_vllm_config) + + # Should always return True when enable_in_reasoning is enabled + result = manager.should_fill_bitmask(mock_request_with_structured_output) + assert result is True + + def test_should_fill_bitmask_without_enable_in_reasoning( + self, + mock_vllm_config, + mock_request_with_structured_output, + mock_reasoning_parser, + ): + """Test should_fill_bitmask when enable_in_reasoning is False.""" + # Keep enable_in_reasoning as False (default) + config = mock_vllm_config.structured_outputs_config + assert config.enable_in_reasoning is False + + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = mock_reasoning_parser + + # Mock reasoning not ended + mock_reasoning_parser.is_reasoning_end.return_value = False + + result = manager.should_fill_bitmask(mock_request_with_structured_output) + + # 
Should set reasoning_ended and return its value + assert ( + mock_request_with_structured_output.structured_output_request.reasoning_ended + is False + ) + assert result is False + + def test_should_fill_bitmask_no_reasoner( + self, mock_vllm_config, mock_request_with_structured_output + ): + """Test should_fill_bitmask when no reasoner is configured.""" + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = None + + result = manager.should_fill_bitmask(mock_request_with_structured_output) + + # Should default to True when no reasoner + assert result is True + + def test_should_advance_with_enable_in_reasoning( + self, + mock_vllm_config, + mock_request_with_structured_output, + mock_reasoning_parser, + ): + """Test should_advance when enable_in_reasoning is True.""" + # Enable enable_in_reasoning + mock_vllm_config.structured_outputs_config.enable_in_reasoning = True + + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = mock_reasoning_parser + + # Should always return True when enable_in_reasoning is enabled + result = manager.should_advance(mock_request_with_structured_output) + assert result is True + + def test_should_advance_reasoning_not_ended( + self, + mock_vllm_config, + mock_request_with_structured_output, + mock_reasoning_parser, + ): + """Test should_advance when reasoning has not ended.""" + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = mock_reasoning_parser + + # Set reasoning as not ended + ( + mock_request_with_structured_output.structured_output_request + ).reasoning_ended = False + mock_reasoning_parser.is_reasoning_end.return_value = False + + result = manager.should_advance(mock_request_with_structured_output) + + # Should return False since reasoning hasn't ended + assert result is False + + def test_should_advance_reasoning_just_ended( + self, + mock_vllm_config, + mock_request_with_structured_output, + mock_reasoning_parser, + ): + """Test should_advance when reasoning ends in current step.""" + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = mock_reasoning_parser + + # Set reasoning as not ended initially, but ends in this step + ( + mock_request_with_structured_output.structured_output_request + ).reasoning_ended = False + mock_reasoning_parser.is_reasoning_end.return_value = True + + result = manager.should_advance(mock_request_with_structured_output) + + # Should set reasoning_ended to True but return False for this step + assert ( + mock_request_with_structured_output.structured_output_request.reasoning_ended + is True + ) + assert result is False + + def test_should_advance_reasoning_already_ended( + self, + mock_vllm_config, + mock_request_with_structured_output, + mock_reasoning_parser, + ): + """Test should_advance when reasoning has already ended.""" + manager = StructuredOutputManager(mock_vllm_config) + manager.reasoner = mock_reasoning_parser + + # Set reasoning as already ended + ( + mock_request_with_structured_output.structured_output_request + ).reasoning_ended = True + + result = manager.should_advance(mock_request_with_structured_output) + + # Should return True since reasoning has ended + assert result is True diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index b285658af3d1..513a21dd6bb3 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -13,7 +13,7 @@ @pytest.fixture def unsupported_string_schemas(): return [ - {"type": "string", "format": "email"}, + 
{"type": "string", "format": "non_existing_format"}, ] @@ -58,6 +58,7 @@ def supported_schema(): "properties": { "name": {"type": "string"}, "age": {"type": "integer"}, + "email": {"type": "string", "format": "email"}, "status": {"type": "string"}, "scores": {"type": "array", "items": {"type": "number"}}, "car_type": {"type": "string", "enum": ["sedan", "suv", "truck"]}, diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index a306a2b040d3..00749c5415c8 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import UserDict from dataclasses import dataclass -from typing import Optional import msgspec import numpy as np @@ -100,7 +99,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): class MyRequest(msgspec.Struct): - mm: Optional[list[MultiModalKwargsItems]] + mm: list[MultiModalKwargsItems] | None def test_multimodal_kwargs(): diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py index f3495b00d3d4..0d53a02476fa 100644 --- a/tests/v1/tpu/test_basic.py +++ b/tests/v1/tpu/test_basic.py @@ -5,8 +5,6 @@ Run `pytest tests/v1/tpu/test_basic.py`. """ -from __future__ import annotations - from typing import TYPE_CHECKING import pytest @@ -16,6 +14,8 @@ if TYPE_CHECKING: from tests.conftest import VllmRunner +else: + VllmRunner = object MODELS = [ "Qwen/Qwen2.5-1.5B-Instruct", diff --git a/tests/v1/tpu/test_perf.py b/tests/v1/tpu/test_perf.py index b7b6835c40cc..e230491cddb0 100644 --- a/tests/v1/tpu/test_perf.py +++ b/tests/v1/tpu/test_perf.py @@ -5,8 +5,6 @@ Run `pytest tests/v1/tpu/test_perf.py`. """ -from __future__ import annotations - import time from dataclasses import dataclass from typing import TYPE_CHECKING @@ -20,6 +18,8 @@ if TYPE_CHECKING: from tests.conftest import VllmRunner +else: + VllmRunner = object @dataclass diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index c2fc24442c7c..c6634395bb16 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -8,10 +8,7 @@ from vllm.platforms import current_platform from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p - -# isort: off from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu -# isort: on if not current_platform.is_tpu(): pytest.skip("This test needs a TPU.", allow_module_level=True) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index df9fcdc37fa3..1aa0709696c4 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -13,7 +13,7 @@ ) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes +from vllm.utils.mem_constants import GiB_bytes from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.tpu_model_runner import ( @@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) 
@@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner): total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids={req_id}, free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner): total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner): total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner): total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner): total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) diff --git a/tests/v1/tracing/test_tracing.py b/tests/v1/tracing/test_tracing.py index 505da4163143..11d9d18ead7d 100644 --- a/tests/v1/tracing/test_tracing.py +++ b/tests/v1/tracing/test_tracing.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # ruff: noqa # type: ignore -from __future__ import annotations - import threading from collections.abc import Iterable from concurrent import futures diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index e72bd43ff56e..6ea65c6944b0 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -3,7 +3,6 @@ import inspect from collections.abc import Sequence -from typing import Optional import numpy as np import pytest @@ -11,7 +10,8 @@ from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata @@ -241,6 +241,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) reqs: list[CachedRequestState] = [] req_id_reqs = {} @@ -270,7 +271,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): reqs, req_ids_retained, input_batch.req_id_to_index, 
device=torch.device(device) ) - def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: + def same(t1: torch.Tensor | None, t2: torch.Tensor | None) -> bool: return (t1 is None and t2 is None) or ( t1 is not None and t2 is not None and torch.allclose(t1, t2) ) @@ -335,6 +336,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) ref_input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, @@ -344,6 +346,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis pin_memory=is_pin_memory_available(), vocab_size=1024, block_sizes=[1], + kernel_block_sizes=[1], ) reqs: list[CachedRequestState] = [] diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index ef2956bd3ec2..c2c34ee95ad5 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -21,7 +21,8 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes, update_environment_variables +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.system_utils import update_environment_variables from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.kv_cache_interface import ( @@ -68,6 +69,9 @@ def initialize_kv_cache(runner: GPUModelRunner): pin_memory=runner.pin_memory, vocab_size=runner.model_config.get_vocab_size(), block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size], + kernel_block_sizes=[ + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ], ) runner.initialize_attn_backend(kv_cache_config) @@ -143,10 +147,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -209,10 +213,10 @@ def test_update_states_request_finished(model_runner, dist_init): total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids={req_id}, free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -241,10 +245,10 @@ def test_update_states_request_resumed(model_runner, dist_init): total_num_scheduled_tokens=0, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -257,6 +261,7 @@ def test_update_states_request_resumed(model_runner, dist_init): req_ids=[req_id], resumed_from_preemption=[False], new_token_ids=[[]], + resumed_req_token_ids=[None], new_block_ids=([[0]],), num_computed_tokens=[0], num_output_tokens=[0], @@ -269,10 +274,10 @@ def test_update_states_request_resumed(model_runner, dist_init): total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, 
scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -362,10 +367,10 @@ def test_update_states_no_changes(model_runner, dist_init): total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -399,10 +404,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init): total_num_scheduled_tokens=1, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[], finished_req_ids=set(), free_encoder_mm_hashes=[], - structured_output_request_ids={}, + structured_output_request_ids=[], grammar_bitmask=None, ) @@ -816,42 +821,231 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape # assert we are using FlashInfer - assert attn_shape[0] == num_blocks + assert attn_shape[0] % num_blocks == 0 + block_split_ratio = attn_shape[0] // num_blocks + + # use small blocks for testing to avoid memory issues + test_block_size = min(2, len(blocks0), len(blocks1)) + + # use non-overlapping blocks to avoid data contamination + # Split kernel blocks: first half for attention, second half for mamba + mid_point = num_blocks // 2 + + # attention uses kernel blocks from first half (mapped to logical blocks) + kv_blocks_for_attention = np.array([0, 1])[:test_block_size] + + # mamba uses kernel blocks from second half + kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size] + + # create small constant tensors for testing with corrected shapes + # attention: [block_size, ...] starting from dimension 2 + attn_constant_shape = attn_shape[2:] + conv_constant_shape = conv_shape[1:] + ssm_constant_shape = ssm_shape[1:] attn_blocks_constant = torch.full( - (len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33 + (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33 ) conv_blocks_constant = torch.full( - (len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66 + (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66 ) ssm_blocks_constant = torch.full( - (len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99 + (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99 ) - # fill all attention blocks with constant + # Fill attention blocks with constants using kv block indices + kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio + for layer in [layer_0, layer_1]: - vllm_ctx[layer].kv_cache[0][blocks0, :] = ( - attn_blocks_constant.detach().clone() - ) + # attention: kv_cache[0][kernel_block_idx, kv_idx, ...] + for i, kernel_block in enumerate(kernel_blocks_for_attention): + vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i] - # fill all mamba blocks with constant + # fill mamba blocks with constants using kernel block indices for layer in [layer_2, layer_3, layer_4, layer_5]: - vllm_ctx[layer].kv_cache[0][0][blocks1, :] = ( - conv_blocks_constant.detach().clone() - ) - vllm_ctx[layer].kv_cache[0][1][blocks1, :] = ( - ssm_blocks_constant.detach().clone() - ) + # mamba: kv_cache[0][component][kernel_block_idx, ...] 
+ for i, kv_block in enumerate(kv_blocks_for_mamba): + vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i] + vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i] # verify attention and mamba contents are correct for layer in [layer_0, layer_1]: - assert torch.equal( - vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant - ) + for i, kernel_block in enumerate(kernel_blocks_for_attention): + actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :] + expected = attn_blocks_constant[i] + + # Check K and V separately + assert torch.equal(actual_kv[0], expected) + assert torch.equal(actual_kv[1], expected) + for layer in [layer_2, layer_3, layer_4, layer_5]: - assert torch.equal( - vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant - ) - assert torch.equal( - vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant - ) + for i, kv_block in enumerate(kv_blocks_for_mamba): + actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] + actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] + expected_conv = conv_blocks_constant[i] + expected_ssm = ssm_blocks_constant[i] + + assert torch.equal(actual_conv, expected_conv) + assert torch.equal(actual_ssm, expected_ssm) + + for layer in [layer_2, layer_3, layer_4, layer_5]: + for i, kv_block in enumerate(kv_blocks_for_mamba): + actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] + actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] + expected_conv = conv_blocks_constant[i] + expected_ssm = ssm_blocks_constant[i] + assert torch.equal(actual_conv, expected_conv) + assert torch.equal(actual_ssm, expected_ssm) + + +def test_hybrid_block_table_initialization(): + """Test hybrid block table with different kernel and kvcache_manager block + sizes.""" + from vllm.v1.worker.block_table import BlockTable + + # Test configuration: kvcache_manager block size = 32, + # kernel block size = 16 + block_size = 32 + kernel_block_sizes = [16] + max_num_reqs = 10 + max_num_blocks_per_req = 20 + max_num_batched_tokens = 512 + + block_table = BlockTable( + block_size=block_size, + max_num_reqs=max_num_reqs, + max_num_blocks_per_req=max_num_blocks_per_req, + max_num_batched_tokens=max_num_batched_tokens, + pin_memory=False, + device=torch.device(DEVICE), + kernel_block_size=kernel_block_sizes[0], + ) + + # Verify hybrid block configuration + assert block_table.use_hybrid_blocks is True + assert block_table.block_size == kernel_block_sizes[0] + assert block_table.blocks_per_kv_block == ( + block_size // kernel_block_sizes[0] + ) # Changed to use first element + + # Test block table conversion logic + # One kvcache_manager block should map to multiple kernel blocks + kvcache_manager_blocks = [0, 1, 2] + + # Verify that kvcache_manager blocks can be converted to kernel blocks + # and that block table operations work correctly. + req_index = 0 + block_table.append_row(kvcache_manager_blocks, req_index) + # Get expected kernel blocks from the implementation for verification. 
+ expected_kernel_blocks = block_table._map_to_kernel_blocks( + np.array(kvcache_manager_blocks) + ) + # Verify block table state + assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks) + assert np.array_equal( + block_table.block_table.np[req_index, : len(expected_kernel_blocks)], + expected_kernel_blocks, + ) + + +def test_input_batch_with_kernel_block_sizes(): + """Test InputBatch initialization with kernel_block_sizes parameter.""" + max_num_reqs = 10 + max_model_len = 512 + max_num_batched_tokens = 512 + device = torch.device(DEVICE) + pin_memory = False + vocab_size = 50272 + + # Test with different kernel block sizes + block_sizes = [32, 64] + kernel_block_sizes = [16, 32] + + input_batch = InputBatch( + max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + device=device, + pin_memory=pin_memory, + vocab_size=vocab_size, + block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, + ) + + # Verify that block tables were created with kernel block sizes + assert len(input_batch.block_table.block_tables) == len(block_sizes) + + for i, (kv_size, kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)): + block_table = input_batch.block_table.block_tables[i] + if kv_size != kernel_size: + assert block_table.use_hybrid_blocks is True + assert block_table.block_size == kernel_size + else: + assert block_table.use_hybrid_blocks is False + assert block_table.block_size == kernel_size + + +def test_hybrid_cache_integration(model_runner, dist_init): + """Test hybrid cache architecture integration with GPUModelRunner.""" + # Create a new model runner with hybrid cache configuration + vllm_config = get_vllm_config() + + # Configure hybrid cache with different kvcache_manager block size + vllm_config.cache_config.block_size = 32 + + model_config = vllm_config.model_config + num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) + head_size = model_config.get_head_size() + vllm_config.compilation_config.static_forward_context["layer.0"] = Attention( + num_heads, head_size, 0.1 + ) + + runner = GPUModelRunner(vllm_config, DEVICE) + + # Initialize KV cache with configuration + attn_spec = FullAttentionSpec( + block_size=16, # Use kernel block size directly + num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config), + head_size=runner.model_config.get_head_size(), + dtype=runner.kv_cache_dtype, + ) + tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS + kv_cache_config = KVCacheConfig( + num_blocks=NUM_BLOCKS, + kv_cache_tensors=[ + KVCacheTensor(size=tensor_size, shared_by=["layer.0"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec) + ], + ) + runner.kv_cache_config = kv_cache_config + + # Initialize input batch with kernel block sizes + runner.input_batch = InputBatch( + max_num_reqs=runner.max_num_reqs, + max_model_len=runner.max_model_len, + max_num_batched_tokens=runner.max_num_tokens, + device=runner.device, + pin_memory=runner.pin_memory, + vocab_size=runner.model_config.get_vocab_size(), + block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size], + kernel_block_sizes=[16], + ) # Use kernel block size + + runner.initialize_attn_backend(kv_cache_config) + + # Verify hybrid block table configuration + block_table = runner.input_batch.block_table.block_tables[0] + assert block_table.block_size == ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ) + + # Test request processing with hybrid 
blocks + req_id = "hybrid_req_0" + scheduler_output = _schedule_new_request(req_id) + + # Update states should work with hybrid blocks + runner._update_states(scheduler_output) + assert _is_req_scheduled(runner, req_id) + assert _is_req_state_block_table_match(runner, req_id) diff --git a/tests/v1/worker/test_worker_memory_snapshot.py b/tests/v1/worker/test_worker_memory_snapshot.py index cbfb9a8dc0b6..66330127b5ec 100644 --- a/tests/v1/worker/test_worker_memory_snapshot.py +++ b/tests/v1/worker/test_worker_memory_snapshot.py @@ -4,19 +4,18 @@ import multiprocessing as mp import os import tempfile -from multiprocessing import Queue -from typing import Optional +from multiprocessing.queues import Queue from unittest.mock import patch import pytest import torch from vllm.engine.arg_utils import EngineArgs -from vllm.utils import MemorySnapshot +from vllm.utils.mem_utils import MemorySnapshot from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment # Global queue to track operation order across processes -_QUEUE: Optional[Queue] = None +_QUEUE: Queue | None = None def track_operation(operation: str, rank: int): diff --git a/tests/vllm_test_utils/vllm_test_utils/blame.py b/tests/vllm_test_utils/vllm_test_utils/blame.py index e2cab92ea22b..9746c3964e21 100644 --- a/tests/vllm_test_utils/vllm_test_utils/blame.py +++ b/tests/vllm_test_utils/vllm_test_utils/blame.py @@ -5,8 +5,7 @@ import dataclasses import sys import traceback -from collections.abc import Generator -from typing import Callable +from collections.abc import Callable, Generator @dataclasses.dataclass diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py index e2f1212ed554..ba22bde8795b 100644 --- a/tests/vllm_test_utils/vllm_test_utils/monitor.py +++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py @@ -5,8 +5,8 @@ import dataclasses import sys import traceback -from collections.abc import Generator -from typing import Callable, Generic, TypeVar +from collections.abc import Callable, Generator +from typing import Generic, TypeVar _T = TypeVar("_T") diff --git a/tools/check_init_lazy_imports.py b/tools/check_init_lazy_imports.py index 9255aa17db6a..8b3a0b2a71be 100644 --- a/tools/check_init_lazy_imports.py +++ b/tools/check_init_lazy_imports.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Ensure we perform lazy loading in vllm/__init__.py. -i.e: appears only within the ``if typing.TYPE_CHECKING:`` guard, +i.e: appears only within the `if typing.TYPE_CHECKING:` guard, **except** for a short whitelist. 
""" -from __future__ import annotations - import ast import pathlib import sys diff --git a/tools/enforce_regex_import.py b/tools/enforce_regex_import.py index 69f43cadc767..a29952e92264 100644 --- a/tools/enforce_regex_import.py +++ b/tools/enforce_regex_import.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import subprocess from pathlib import Path diff --git a/tools/ep_kernels/install_python_libraries.sh b/tools/ep_kernels/install_python_libraries.sh index 5a3d734190c1..c2d8d1ed9e3d 100644 --- a/tools/ep_kernels/install_python_libraries.sh +++ b/tools/ep_kernels/install_python_libraries.sh @@ -119,7 +119,7 @@ popd # build and install deepep, require pytorch installed pushd $WORKSPACE -clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "e3908bf" +clone_repo "https://github.com/deepseek-ai/DeepEP" "DeepEP" "setup.py" "73b6ea4" cd DeepEP export NVSHMEM_DIR=$WORKSPACE/nvshmem_install $PIP_CMD install --no-build-isolation -vvv -e . diff --git a/tools/install_gdrcopy.sh b/tools/install_gdrcopy.sh index 481723320c63..d8a756879978 100755 --- a/tools/install_gdrcopy.sh +++ b/tools/install_gdrcopy.sh @@ -7,18 +7,15 @@ set -euo pipefail # Requires: curl, apt-get, root privileges if [[ $(id -u) -ne 0 ]]; then echo "Must be run as root" >&2 - exit 1 fi if [[ $# -ne 3 ]]; then echo "Usage: $0 <GDRCOPY_OS_VERSION> <GDRCOPY_CUDA_VERSION> <uuarch(x64|aarch64)>" >&2 exit 1 fi - OS_VER="$1" CUDA_VER="$2" UUARCH_RAW="$3" - # Normalize/validate arch case "${UUARCH_RAW,,}" in aarch64|arm64) diff --git a/tools/install_nixl_from_source_ubuntu.py b/tools/install_nixl_from_source_ubuntu.py new file mode 100644 index 000000000000..c808b01d2e94 --- /dev/null +++ b/tools/install_nixl_from_source_ubuntu.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# install_prerequisites.py +import argparse +import glob +import os +import subprocess +import sys + +# --- Configuration --- +WHEELS_CACHE_HOME = os.environ.get("WHEELS_CACHE_HOME", "/tmp/wheels_cache") +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +UCX_DIR = os.path.join("/tmp", "ucx_source") +NIXL_DIR = os.path.join("/tmp", "nixl_source") +UCX_INSTALL_DIR = os.path.join("/tmp", "ucx_install") +UCX_REPO_URL = "https://github.com/openucx/ucx.git" +NIXL_REPO_URL = "https://github.com/ai-dynamo/nixl.git" + + +# --- Helper Functions --- +def run_command(command, cwd=".", env=None): + """Helper function to run a shell command and check for errors.""" + print(f"--> Running command: {' '.join(command)} in '{cwd}'", flush=True) + subprocess.check_call(command, cwd=cwd, env=env) + + +def is_pip_package_installed(package_name): + """Checks if a package is installed via pip without raising an exception.""" + result = subprocess.run( + [sys.executable, "-m", "pip", "show", package_name], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + return result.returncode == 0 + + +def find_nixl_wheel_in_cache(cache_dir): + """Finds a nixl wheel file in the specified cache directory.""" + # The repaired wheel will have a 'manylinux' tag, but this glob still works. 
+ search_pattern = os.path.join(cache_dir, "nixl-*.whl") + wheels = glob.glob(search_pattern) + if wheels: + # Sort to get the most recent/highest version if multiple exist + wheels.sort() + return wheels[-1] + return None + + +def install_system_dependencies(): + """Installs required system packages using apt-get if run as root.""" + if os.geteuid() != 0: + print("\n---", flush=True) + print( + "WARNING: Not running as root. \ + Skipping system dependency installation.", + flush=True, + ) + print( + "Please ensure the listed packages are installed on your system:", + flush=True, + ) + print( + " patchelf build-essential git cmake ninja-build \ + autotools-dev automake meson libtool libtool-bin", + flush=True, + ) + print("---\n", flush=True) + return + + print("--- Running as root. Installing system dependencies... ---", flush=True) + apt_packages = [ + "patchelf", # <-- Add patchelf here + "build-essential", + "git", + "cmake", + "ninja-build", + "autotools-dev", + "automake", + "meson", + "libtool", + "libtool-bin", + ] + run_command(["apt-get", "update"]) + run_command(["apt-get", "install", "-y"] + apt_packages) + print("--- System dependencies installed successfully. ---\n", flush=True) + + +def build_and_install_prerequisites(args): + """Builds UCX and NIXL from source, creating a self-contained wheel.""" + + if not args.force_reinstall and is_pip_package_installed("nixl"): + print("--> NIXL is already installed. Nothing to do.", flush=True) + return + + cached_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME) + if not args.force_reinstall and cached_wheel: + print( + f"\n--> Found self-contained wheel: \ + {os.path.basename(cached_wheel)}.", + flush=True, + ) + print("--> Installing from cache, skipping all source builds.", flush=True) + install_command = [sys.executable, "-m", "pip", "install", cached_wheel] + run_command(install_command) + print("\n--- Installation from cache complete. ---", flush=True) + return + + print( + "\n--> No installed package or cached wheel found. 
\ + Starting full build process...", + flush=True, + ) + print("\n--> Installing auditwheel...", flush=True) + run_command([sys.executable, "-m", "pip", "install", "auditwheel"]) + install_system_dependencies() + ucx_install_path = os.path.abspath(UCX_INSTALL_DIR) + print(f"--> Using wheel cache directory: {WHEELS_CACHE_HOME}", flush=True) + os.makedirs(WHEELS_CACHE_HOME, exist_ok=True) + + # -- Step 1: Build UCX from source -- + print("\n[1/3] Configuring and building UCX from source...", flush=True) + if not os.path.exists(UCX_DIR): + run_command(["git", "clone", UCX_REPO_URL, UCX_DIR]) + ucx_source_path = os.path.abspath(UCX_DIR) + run_command(["git", "checkout", "v1.19.x"], cwd=ucx_source_path) + run_command(["./autogen.sh"], cwd=ucx_source_path) + configure_command = [ + "./configure", + f"--prefix={ucx_install_path}", + "--enable-shared", + "--disable-static", + "--disable-doxygen-doc", + "--enable-optimizations", + "--enable-cma", + "--enable-devel-headers", + "--with-verbs", + "--enable-mt", + "--with-ze=no", + ] + run_command(configure_command, cwd=ucx_source_path) + run_command(["make", "-j", str(os.cpu_count() or 1)], cwd=ucx_source_path) + run_command(["make", "install"], cwd=ucx_source_path) + print("--- UCX build and install complete ---", flush=True) + + # -- Step 2: Build NIXL wheel from source -- + print("\n[2/3] Building NIXL wheel from source...", flush=True) + if not os.path.exists(NIXL_DIR): + run_command(["git", "clone", NIXL_REPO_URL, NIXL_DIR]) + + build_env = os.environ.copy() + build_env["PKG_CONFIG_PATH"] = os.path.join(ucx_install_path, "lib", "pkgconfig") + ucx_lib_path = os.path.join(ucx_install_path, "lib") + ucx_plugin_path = os.path.join(ucx_lib_path, "ucx") + existing_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + build_env["LD_LIBRARY_PATH"] = ( + f"{ucx_lib_path}:{ucx_plugin_path}:{existing_ld_path}".strip(":") + ) + print(f"--> Using LD_LIBRARY_PATH: {build_env['LD_LIBRARY_PATH']}", flush=True) + + temp_wheel_dir = os.path.join(ROOT_DIR, "temp_wheelhouse") + run_command( + [ + sys.executable, + "-m", + "pip", + "wheel", + ".", + "--no-deps", + f"--wheel-dir={temp_wheel_dir}", + ], + cwd=os.path.abspath(NIXL_DIR), + env=build_env, + ) + + # -- Step 3: Repair the wheel by copying UCX libraries -- + print("\n[3/3] Repairing NIXL wheel to include UCX libraries...", flush=True) + unrepaired_wheel = find_nixl_wheel_in_cache(temp_wheel_dir) + if not unrepaired_wheel: + raise RuntimeError("Failed to find the NIXL wheel after building it.") + + # We tell auditwheel to ignore the plugin that mesonpy already handled. + auditwheel_command = [ + "auditwheel", + "repair", + "--exclude", + "libplugin_UCX.so", # <-- Exclude because mesonpy already includes it + unrepaired_wheel, + f"--wheel-dir={WHEELS_CACHE_HOME}", + ] + run_command(auditwheel_command, env=build_env) + + # --- CLEANUP --- + # No more temporary files to remove, just the temp wheelhouse + run_command(["rm", "-rf", temp_wheel_dir]) + # --- END CLEANUP --- + + newly_built_wheel = find_nixl_wheel_in_cache(WHEELS_CACHE_HOME) + if not newly_built_wheel: + raise RuntimeError("Failed to find the repaired NIXL wheel.") + + print( + f"--> Successfully built self-contained wheel: \ + {os.path.basename(newly_built_wheel)}. 
Now installing...", + flush=True, + ) + install_command = [sys.executable, "-m", "pip", "install", newly_built_wheel] + if args.force_reinstall: + install_command.insert(-1, "--force-reinstall") + + run_command(install_command) + print("--- NIXL installation complete ---", flush=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Build and install UCX and NIXL dependencies." + ) + parser.add_argument( + "--force-reinstall", + action="store_true", + help="Force rebuild and reinstall of UCX and NIXL \ + even if they are already installed.", + ) + args = parser.parse_args() + build_and_install_prerequisites(args) diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index bceb894a7a5f..c9256cd91a4e 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -17,18 +17,18 @@ # add to this list if absolutely necessary and after careful security review. ALLOWED_FILES = { # pickle - "vllm/v1/serial_utils.py", - "vllm/v1/executor/multiproc_executor.py", "vllm/multimodal/hasher.py", "vllm/transformers_utils/config.py", "vllm/model_executor/models/registry.py", - "tests/utils_/test_utils.py", - "tests/tokenization/test_cached_tokenizer.py", + "vllm/compilation/caching.py", "vllm/distributed/utils.py", "vllm/distributed/parallel_state.py", "vllm/distributed/device_communicators/all_reduce_utils.py", "vllm/distributed/device_communicators/shm_broadcast.py", "vllm/distributed/device_communicators/shm_object_storage.py", + "vllm/utils/hashing.py", + "tests/utils_/test_hashing.py", + "tests/tokenization/test_cached_tokenizer.py", "benchmarks/kernels/graph_machete_bench.py", "benchmarks/kernels/benchmark_lora.py", "benchmarks/kernels/benchmark_machete.py", @@ -36,12 +36,13 @@ "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py", "benchmarks/cutlass_benchmarks/sparse_benchmarks.py", # cloudpickle - "vllm/executor/mp_distributed_executor.py", - "vllm/executor/ray_distributed_executor.py", + "vllm/v1/executor/multiproc_executor.py", + "vllm/v1/executor/ray_executor.py", "vllm/entrypoints/llm.py", + "vllm/utils/__init__.py", "tests/utils.py", # pickle and cloudpickle - "vllm/utils/__init__.py", + "vllm/v1/serial_utils.py", } PICKLE_RE = re.compile( diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 22ee08535bdd..a3aa54634725 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -20,14 +20,15 @@ import subprocess import sys -from typing import Optional import regex as re FILES = [ "vllm/*.py", "vllm/assets", + "vllm/distributed", "vllm/entrypoints", + "vllm/executor", "vllm/inputs", "vllm/logging_utils", "vllm/multimodal", @@ -43,9 +44,7 @@ "tests", "vllm/attention", "vllm/compilation", - "vllm/distributed", "vllm/engine", - "vllm/executor", "vllm/inputs", "vllm/lora", "vllm/model_executor", @@ -96,8 +95,8 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]: def mypy( targets: list[str], - python_version: Optional[str], - follow_imports: Optional[str], + python_version: str | None, + follow_imports: str | None, file_group: str, ) -> int: """ diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index cdab004366f9..a049dc0425dd 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -7,7 +7,7 @@ import math import os from pathlib import Path -from typing import Any, Optional +from typing import Any import matplotlib.pyplot as 
plt import pandas as pd @@ -373,7 +373,7 @@ def plot_trace_df( traces_df: pd.DataFrame, plot_metric: str, plot_title: str, - output: Optional[Path] = None, + output: Path | None = None, ): def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: phase_df = traces_df.query(f'phase == "{phase}"') diff --git a/tools/validate_config.py b/tools/validate_config.py index d779edabc841..fb6f0e6a9285 100644 --- a/tools/validate_config.py +++ b/tools/validate_config.py @@ -8,6 +8,7 @@ import ast import inspect import sys +from itertools import pairwise import regex as re @@ -20,19 +21,6 @@ def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]: https://davidism.com/mit-license/ """ - def pairwise(iterable): - """ - Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise - - Can be removed when Python 3.9 support is dropped. - """ - iterator = iter(iterable) - a = next(iterator, None) - - for b in iterator: - yield a, b - a = b - out = {} # Consider each pair of nodes. diff --git a/tools/vllm-tpu/build.sh b/tools/vllm-tpu/build.sh new file mode 100644 index 000000000000..fbc91e379df3 --- /dev/null +++ b/tools/vllm-tpu/build.sh @@ -0,0 +1,67 @@ +#!/bin/bash +set -e # Exit immediately if a command exits with a non-zero status. +# Script to build VLLM wheel for TPU with an optional version override. + +SCRIPT_PATH_PARAM="$0" +TOOLS_DIR=$(cd "$(dirname "$SCRIPT_PATH_PARAM")" && pwd) # Absolute path to the script's directory +REPO_ROOT=$(cd "$TOOLS_DIR/../../" && pwd) # Absolute path to the repo root +VLLM_DIR="$REPO_ROOT/" # Path to the vllm sources + +# Ensure we are not running from within the vllm directory if SCRIPT_PATH_PARAM is relative like "." +if [ "$TOOLS_DIR" = "$VLLM_DIR" ]; then + echo "Error: This script should not be run from the vllm directory directly if using relative paths." + echo "Place it in a subdirectory like 'tools/vllm-tpu' and run it from the repository root or via its full path." + exit 1 +fi + +# Optional version argument +if [ -n "$1" ]; then + USER_VERSION="$1" + export VLLM_VERSION_OVERRIDE="$USER_VERSION" + echo "User defined version: $USER_VERSION" +else + echo "No version override supplied. Using default version from source." +fi + +PYPROJECT_FILE="$VLLM_DIR/pyproject.toml" + +# Backup and update the project name. +if ! grep -q "name = \"vllm-tpu\"" "$PYPROJECT_FILE"; then + echo "Patching pyproject.toml project name to vllm-tpu..." + cp "$PYPROJECT_FILE" "${PYPROJECT_FILE}.bak" + sed -i '0,/^name = "vllm"/s//name = "vllm-tpu"/' "$PYPROJECT_FILE" + PATCHED=true +else + PATCHED=false +fi + +# Navigate to the vllm directory +cd "$VLLM_DIR" + +# Cleanup function to be called on exit or error +cleanup() { + echo "Cleaning up..." + if [ "$PATCHED" = true ]; then + echo "Restoring original pyproject.toml..." + cp "${PYPROJECT_FILE}.bak" "$PYPROJECT_FILE" + rm -f "${PYPROJECT_FILE}.bak" + fi +} +trap cleanup EXIT HUP INT QUIT PIPE TERM # Register cleanup function to run on script exit and various signals + +echo "Updating pyproject.toml completed. Proceeding with build..." + +echo "Building wheel for TPU..." +rm -rf dist/ +mkdir -p dist/ + +# User confirmed to use 'python -m build' directly +if ! VLLM_TARGET_DEVICE=tpu python -m build; then + echo "Error: Python build command failed. Check if 'python -m build' works and the 'build' module is installed." 
+ exit 1 +fi + +trap - EXIT HUP INT QUIT PIPE TERM +cleanup + +exit 0 \ No newline at end of file diff --git a/vllm/__init__.py b/vllm/__init__.py index b9c868de6886..19b2cdc673c4 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -21,7 +21,7 @@ "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine", "LLMEngine": ".engine.llm_engine:LLMEngine", "LLM": ".entrypoints.llm:LLM", - "initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster", + "initialize_ray_cluster": ".v1.executor.ray_utils:initialize_ray_cluster", "PromptType": ".inputs:PromptType", "TextPrompt": ".inputs:TextPrompt", "TokensPrompt": ".inputs:TokensPrompt", @@ -45,7 +45,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM - from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import ( @@ -62,6 +61,7 @@ ) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams + from vllm.v1.executor.ray_utils import initialize_ray_cluster from ._bc_linter import bc_linter_include, bc_linter_skip else: diff --git a/vllm/_bc_linter.py b/vllm/_bc_linter.py index af68396af0b5..2929a8bce85a 100644 --- a/vllm/_bc_linter.py +++ b/vllm/_bc_linter.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # vllm/_bc_linter.py -from __future__ import annotations - -from typing import Any, Callable, TypeVar, overload +from collections.abc import Callable +from typing import Any, TypeVar, overload T = TypeVar("T") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9fa346cca56d..9110b0573fc9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Literal import torch @@ -12,8 +12,7 @@ logger = init_logger(__name__) -current_platform.import_core_kernels() -supports_moe_ops = current_platform.try_import_moe_kernels() +current_platform.import_kernels() if TYPE_CHECKING: @@ -38,7 +37,7 @@ def paged_attention_v1( seq_lens: torch.Tensor, block_size: int, max_seq_len: int, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -85,7 +84,7 @@ def paged_attention_v2( seq_lens: torch.Tensor, block_size: int, max_seq_len: int, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -133,14 +132,14 @@ def paged_attention_rocm( scale: float, block_tables: torch.Tensor, seq_lens: torch.Tensor, - query_start_loc: Optional[torch.Tensor], + query_start_loc: torch.Tensor | None, block_size: int, max_seq_len: int, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, - fp8_out_scale: Optional[torch.Tensor] = None, + fp8_out_scale: torch.Tensor | None = None, mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16", ) -> None: torch.ops._rocm_C.paged_attention( @@ -187,7 +186,7 @@ def merge_attn_states( prefix_lse: torch.Tensor, suffix_output: torch.Tensor, suffix_lse: torch.Tensor, - output_lse: Optional[torch.Tensor] = None, 
+ output_lse: torch.Tensor | None = None, ) -> None: torch.ops._C.merge_attn_states( output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse @@ -315,7 +314,7 @@ def convert_vertical_slash_indexes_mergehead( def rotary_embedding( positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor], + key: torch.Tensor | None, head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, @@ -340,18 +339,6 @@ def fused_add_rms_norm( torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) -def poly_norm( - out: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - epsilon: float, -) -> None: - # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input - input_contiguous = input.contiguous() - torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon) - - def apply_repetition_penalties_torch( logits: torch.Tensor, prompt_mask: torch.Tensor, @@ -409,8 +396,8 @@ def rms_norm_dynamic_per_token_quant( weight: torch.Tensor, epsilon: float, quant_dtype: torch.dtype, - scale_ub: Optional[torch.Tensor] = None, - residual: Optional[torch.Tensor] = None, + scale_ub: torch.Tensor | None = None, + residual: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=quant_dtype) scales = torch.empty( @@ -464,10 +451,18 @@ def gptq_gemm( b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, + use_v2_format: bool, bit: int, ) -> torch.Tensor: return torch.ops._C.gptq_gemm( - a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit + a, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + use_exllama, + use_v2_format, + bit, ) @@ -481,6 +476,7 @@ def _gptq_gemm_fake( b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, + use_v2_format: bool, bit: int, ) -> torch.Tensor: return torch.empty( @@ -528,14 +524,14 @@ def _gptq_marlin_24_gemm_fake( @register_fake("_C::gptq_marlin_gemm") def _gptq_marlin_gemm_fake( a: torch.Tensor, - c: Optional[torch.Tensor], + c: torch.Tensor | None, b_q_weight: torch.Tensor, - b_bias: Optional[torch.Tensor], + b_bias: torch.Tensor | None, b_scales: torch.Tensor, - global_scale: Optional[torch.Tensor], - b_zeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], + global_scale: torch.Tensor | None, + b_zeros: torch.Tensor | None, + g_idx: torch.Tensor | None, + perm: torch.Tensor | None, workspace: torch.Tensor, b_q_type_id: int, size_m: torch.SymInt, @@ -583,13 +579,13 @@ def machete_mm_fake( # b_q Should be the tensor returned by machete_prepack_B b_q: torch.Tensor, b_type: ScalarType, - out_type: Optional[torch.dtype] = None, - b_group_scales: Optional[torch.Tensor] = None, - b_group_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - b_channel_scales: Optional[torch.Tensor] = None, - a_token_scales: Optional[torch.Tensor] = None, - schedule: Optional[str] = None, + out_type: torch.dtype | None = None, + b_group_scales: torch.Tensor | None = None, + b_group_zeros: torch.Tensor | None = None, + b_group_size: int | None = None, + b_channel_scales: torch.Tensor | None = None, + a_token_scales: torch.Tensor | None = None, + schedule: str | None = None, ) -> torch.Tensor: m = a.size(0) n = b_q.size(1) @@ -600,7 +596,7 @@ def machete_prepack_B_fake( b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, - group_scales_type: Optional[torch.dtype], + group_scales_type: torch.dtype | None, ) -> 
torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) @@ -613,8 +609,8 @@ def cutlass_w4a8_mm_fake( b_group_size: int, b_channel_scales: torch.Tensor, a_token_scales: torch.Tensor, - out_type: Optional[torch.dtype] = None, - maybe_schedule: Optional[str] = None, + out_type: torch.dtype | None = None, + maybe_schedule: str | None = None, ) -> torch.Tensor: m = a.size(0) n = b_q.size(1) @@ -637,7 +633,7 @@ def _allspark_w8a16_gemm_fake( a: torch.Tensor, b_qweight: torch.Tensor, b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], + b_qzeros: torch.Tensor | None, n: torch.SymInt, group_size: torch.SymInt, sm_count: torch.SymInt, @@ -658,7 +654,7 @@ def _ggml_dequantize_fake( quant_type: int, m: torch.SymInt, n: torch.SymInt, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) @@ -761,7 +757,7 @@ def cutlass_scaled_mm( scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: """ `cutlass_scaled_mm` implements a fused version of @@ -813,8 +809,8 @@ def cutlass_scaled_mm_azp( scale_b: torch.Tensor, out_dtype: torch.dtype, azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, + azp: torch.Tensor | None = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: """ :param azp_adj: In the per-tensor case, this should include the azp. @@ -891,7 +887,7 @@ def cutlass_scaled_sparse_mm( scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: """ Performs a scaled sparse matrix multiplication using Cutlass. 
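# Editor's note (illustrative sketch, not part of the patch): the recurring edit in
# this file is the PEP 604 typing migration, replacing typing.Optional/Union with the
# builtin "X | Y" union syntax. A minimal example with a placeholder function name;
# runtime behavior is unchanged, only the annotations differ.
import torch


def scale_or_default(scale: torch.Tensor | None, default: float = 1.0) -> torch.Tensor:
    # Old style would have been: scale: Optional[torch.Tensor]
    return scale if scale is not None else torch.tensor(default, dtype=torch.float32)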
@@ -941,7 +937,7 @@ def get_cutlass_moe_mm_data( num_experts: int, n: int, k: int, - blockscale_offsets: Optional[torch.Tensor] = None, + blockscale_offsets: torch.Tensor | None = None, ): """ Prepare data necessary to perform CUTLASS grouped matrix multiplications @@ -987,7 +983,7 @@ def get_cutlass_moe_mm_problem_sizes( num_experts: int, n: int, k: int, - blockscale_offsets: Optional[torch.Tensor] = None, + blockscale_offsets: torch.Tensor | None = None, ): """ Compute only the per-expert problem sizes needed by the two grouped matrix @@ -1196,14 +1192,14 @@ def awq_marlin_moe_repack( def gptq_marlin_gemm( a: torch.Tensor, - c: Optional[torch.Tensor], + c: torch.Tensor | None, b_q_weight: torch.Tensor, - b_bias: Optional[torch.Tensor], + b_bias: torch.Tensor | None, b_scales: torch.Tensor, - global_scale: Optional[torch.Tensor], - b_zeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], + global_scale: torch.Tensor | None, + b_zeros: torch.Tensor | None, + g_idx: torch.Tensor | None, + perm: torch.Tensor | None, workspace: torch.Tensor, b_q_type: ScalarType, size_m: int, @@ -1240,11 +1236,11 @@ def gptq_marlin_gemm( def machete_supported_schedules( a_type: torch.dtype, b_type: ScalarType, - group_scales_type: Optional[torch.dtype], - group_zeros_type: Optional[torch.dtype] = None, - channel_scales_type: Optional[torch.dtype] = None, - token_scales_type: Optional[torch.dtype] = None, - out_type: Optional[torch.dtype] = None, + group_scales_type: torch.dtype | None, + group_zeros_type: torch.dtype | None = None, + channel_scales_type: torch.dtype | None = None, + token_scales_type: torch.dtype | None = None, + out_type: torch.dtype | None = None, ) -> list[str]: return torch.ops._C.machete_supported_schedules( a_type, @@ -1262,13 +1258,13 @@ def machete_mm( # b_q Should be the tensor returned by machete_prepack_B b_q: torch.Tensor, b_type: ScalarType, - out_type: Optional[torch.dtype] = None, - b_group_scales: Optional[torch.Tensor] = None, - b_group_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - b_channel_scales: Optional[torch.Tensor] = None, - a_token_scales: Optional[torch.Tensor] = None, - schedule: Optional[str] = None, + out_type: torch.dtype | None = None, + b_group_scales: torch.Tensor | None = None, + b_group_zeros: torch.Tensor | None = None, + b_group_size: int | None = None, + b_channel_scales: torch.Tensor | None = None, + a_token_scales: torch.Tensor | None = None, + schedule: str | None = None, ) -> torch.Tensor: return torch.ops._C.machete_mm( a, @@ -1288,7 +1284,7 @@ def machete_prepack_B( b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, - group_scales_type: Optional[torch.dtype], + group_scales_type: torch.dtype | None, ) -> torch.Tensor: return torch.ops._C.machete_prepack_B( b_q_weight, a_type, b_type.id, group_scales_type @@ -1304,8 +1300,8 @@ def cutlass_w4a8_mm( b_group_size: int, b_channel_scales: torch.Tensor, a_token_scales: torch.Tensor, - out_type: Optional[torch.dtype] = None, - maybe_schedule: Optional[str] = None, + out_type: torch.dtype | None = None, + maybe_schedule: str | None = None, ) -> torch.Tensor: return torch.ops._C.cutlass_w4a8_mm( a, @@ -1385,7 +1381,7 @@ def scaled_fp4_quant( rounded_m = round_up(m, 128) scale_n = n // block_size rounded_n = round_up(scale_n, 4) - output_scale = torch.empty( + output_scale = torch.zeros( (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 ) @@ -1459,11 +1455,11 @@ def scaled_fp4_experts_quant( # fp8 def 
scaled_fp8_quant( input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, + scale: torch.Tensor | None = None, + num_token_padding: int | None = None, + scale_ub: torch.Tensor | None = None, use_per_token_if_dynamic: bool = False, - output: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -1490,7 +1486,7 @@ def scaled_fp8_quant( """ # This code assumes batch_dim and num_tokens are flattened assert input.ndim == 2 - shape: Union[tuple[int, int], torch.Size] = input.shape + shape: tuple[int, int] | torch.Size = input.shape # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: @@ -1508,7 +1504,7 @@ def scaled_fp8_quant( output, input, scale, scale_ub ) else: - scale = torch.empty(1, device=input.device, dtype=torch.float32) + scale = torch.empty((1, 1), device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: assert scale.numel() == 1, f"{scale.shape}" @@ -1521,7 +1517,7 @@ def scaled_fp8_quant( def allspark_repack_weight( qweight: torch.Tensor, scale: torch.Tensor, - zero_point: Optional[torch.Tensor] = None, + zero_point: torch.Tensor | None = None, has_zp: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -1577,7 +1573,7 @@ def allspark_w8a16_gemm( a: torch.Tensor, b_qweight: torch.Tensor, b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], + b_qzeros: torch.Tensor | None, n: int, group_size: int, sm_count: int, @@ -1604,10 +1600,10 @@ def allspark_w8a16_gemm( # int8 def scaled_int8_quant( input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, + scale: torch.Tensor | None = None, + azp: torch.Tensor | None = None, symmetric: bool = True, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. 
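# Editor's note (illustrative usage sketch, not part of the patch): calling the
# scaled_fp8_quant wrapper whose signature is updated above. Assumes a CUDA-capable
# device with the vLLM custom kernels loaded; the tensor shapes are examples only.
import torch

from vllm import _custom_ops as ops

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")  # [num_tokens, hidden]
# scale=None requests dynamic quantization; use_per_token_if_dynamic selects a
# per-token scale instead of a single per-tensor scale.
x_fp8, x_scale = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)
assert x_fp8.shape == x.shape and x_scale.dtype == torch.float32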
@@ -1644,7 +1640,7 @@ def scaled_int8_quant( # gguf def ggml_dequantize( - W: torch.Tensor, quant_type: int, m: int, n: int, dtype: Optional[torch.dtype] + W: torch.Tensor, quant_type: int, m: int, n: int, dtype: torch.dtype | None ) -> torch.Tensor: return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype) @@ -1714,13 +1710,13 @@ def selective_scan_fwd( A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], - delta_bias_: Optional[torch.Tensor], + D_: torch.Tensor | None, + z_: torch.Tensor | None, + delta_bias_: torch.Tensor | None, delta_softplus: bool, - query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], + query_start_loc: torch.Tensor | None, + cache_indices: torch.Tensor | None, + has_initial_state: torch.Tensor | None, ssm_states: torch.Tensor, pad_slot_id: int, ): @@ -1790,13 +1786,57 @@ def moe_align_block_size( ) +def batched_moe_align_block_size( + max_tokens_per_batch: int, + block_size: int, + expert_num_tokens: torch.Tensor, + sorted_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + torch.ops._moe_C.batched_moe_align_block_size( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + + +def moe_lora_align_block_size( + topk_ids: torch.Tensor, + token_lora_mapping: torch.Tensor, + num_experts: int, + block_size: int, + max_loras: int, + max_num_tokens_padded: int, + max_num_m_blocks: int, + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + torch.ops._moe_C.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + ) + + def moe_wna16_gemm( input: torch.Tensor, output: torch.Tensor, b_qweight: torch.Tensor, b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor], + b_qzeros: torch.Tensor | None, + topk_weights: torch.Tensor | None, sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor, @@ -1833,9 +1873,10 @@ def topk_softmax( topk_ids: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, + renormalize: bool = False, ) -> None: torch.ops._moe_C.topk_softmax( - topk_weights, topk_ids, token_expert_indices, gating_output + topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) @@ -1865,14 +1906,14 @@ def grouped_topk( def moe_wna16_marlin_gemm( input: torch.Tensor, - output: Optional[torch.Tensor], + output: torch.Tensor | None, b_qweight: torch.Tensor, - b_bias: Optional[torch.Tensor], + b_bias: torch.Tensor | None, b_scales: torch.Tensor, - global_scale: Optional[torch.Tensor], - b_qzeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], + global_scale: torch.Tensor | None, + b_qzeros: torch.Tensor | None, + g_idx: torch.Tensor | None, + perm: torch.Tensor | None, workspace: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, @@ -1921,7 +1962,7 @@ def moe_wna16_marlin_gemm( ) -if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): +if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): @register_fake("_moe_C::marlin_gemm_moe") def marlin_gemm_moe_fake( @@ -1951,12 +1992,12 @@ def marlin_gemm_moe_fake( 
@register_fake("_moe_C::moe_wna16_marlin_gemm") def moe_wna16_marlin_gemm_fake( input: torch.Tensor, - output: Optional[torch.Tensor], + output: torch.Tensor | None, b_qweight: torch.Tensor, b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], + b_qzeros: torch.Tensor | None, + g_idx: torch.Tensor | None, + perm: torch.Tensor | None, workspace: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, @@ -2069,7 +2110,7 @@ def gather_and_maybe_dequant_cache( batch_size: int, kv_cache_dtype: str, scale: torch.Tensor, - seq_starts: Optional[torch.Tensor] = None, + seq_starts: torch.Tensor | None = None, ) -> None: torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( src_cache, @@ -2089,7 +2130,7 @@ def cp_gather_cache( block_table: torch.Tensor, cu_seq_lens: torch.Tensor, batch_size: int, - seq_starts: Optional[torch.Tensor] = None, + seq_starts: torch.Tensor | None = None, ) -> None: torch.ops._C_cache_ops.cp_gather_cache( src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts @@ -2188,9 +2229,7 @@ def free_shared_buffer(ptr: int) -> None: # quick all reduce -def init_custom_qr( - rank: int, world_size: int, qr_max_size: Optional[int] = None -) -> int: +def init_custom_qr(rank: int, world_size: int, qr_max_size: int | None = None) -> int: return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size) @@ -2248,7 +2287,7 @@ def flash_mla_with_kvcache( head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, - softmax_scale: Optional[float] = None, + softmax_scale: float | None = None, causal: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -2325,7 +2364,7 @@ def sm100_cutlass_mla_get_workspace_size( def weight_packed_linear_fake( mat1: torch.Tensor, mat2: torch.Tensor, - bias: Optional[torch.Tensor], + bias: torch.Tensor | None, is_vnni: bool, ) -> torch.Tensor: return torch.empty( @@ -2345,11 +2384,11 @@ def fused_experts_cpu_fake( inplace: bool, use_int8_w8a8: bool, use_fp8_w8a16: bool, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - block_size: Optional[list[int]], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + w1_scale: torch.Tensor | None, + w2_scale: torch.Tensor | None, + block_size: list[int] | None, + a1_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, is_vnni: bool, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -2362,7 +2401,7 @@ def int8_scaled_mm_with_quant_fake( mat1: torch.Tensor, mat2: torch.Tensor, scales2: torch.Tensor, - bias: Optional[torch.Tensor], + bias: torch.Tensor | None, out_dtype: torch.dtype, is_vnni: bool, ) -> torch.Tensor: @@ -2373,7 +2412,7 @@ def int8_scaled_mm_with_quant_fake( class CPUDNNLGEMMHandler: def __init__(self) -> None: - self.handler: Optional[int] = None + self.handler: int | None = None self.n = -1 self.k = -1 @@ -2404,7 +2443,7 @@ def create_onednn_mm( def onednn_mm( dnnl_handler: CPUDNNLGEMMHandler, x: torch.Tensor, - bias: Optional[torch.Tensor], + bias: torch.Tensor | None, ) -> torch.Tensor: output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) torch.ops._C.onednn_mm( @@ -2432,8 +2471,8 @@ def create_onednn_scaled_mm( def onednn_scaled_int8_quant( input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, + scale: torch.Tensor | None = None, + azp: torch.Tensor | None = None, symmetric: bool = True, ): """ @@ -2472,10 +2511,10 @@ def onednn_scaled_mm( 
dnnl_handler: CPUDNNLGEMMHandler, x: torch.Tensor, output: torch.Tensor, - input_scale: Optional[torch.Tensor], - input_zp: Optional[torch.Tensor], - input_zp_adj: Optional[torch.Tensor], - bias: Optional[torch.Tensor], + input_scale: torch.Tensor | None, + input_zp: torch.Tensor | None, + input_zp_adj: torch.Tensor | None, + bias: torch.Tensor | None, ) -> torch.Tensor: torch.ops._C.onednn_scaled_mm( output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler @@ -2484,6 +2523,144 @@ def onednn_scaled_mm( return output +if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): + + @register_fake("_qutlass_C::matmul_mxf4_bf16_tn") + def _fake_matmul_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, + ): + return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16) + + +def matmul_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return torch.ops._qutlass_C.matmul_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha) + + +if hasattr(torch.ops._qutlass_C, "matmul_ada_mxf4_bf16_tn"): + + @register_fake("_qutlass_C::matmul_ada_mxf4_bf16_tn") + def _fake_matmul_ada_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, + ): + return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16) + + +def matmul_ada_mxf4_bf16_tn( + a: torch.Tensor, + b: torch.Tensor, + a_sf: torch.Tensor, + b_sf: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha) + + +def ceil_div(a, b): + return (a + b - 1) // b + + +if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxQuest"): + + @register_fake("_qutlass_C::fusedQuantizeMxQuest") + def _fake_fused_quantize_mx_quest( + a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor + ): + return xh_e2m1, xh_e8m0 + + +if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxAbsMax"): + + @register_fake("_qutlass_C::fusedQuantizeMxAbsMax") + def _fake_fused_quantize_mx_absmax( + a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor + ): + return xh_e2m1, xh_e8m0 + + +def fusedQuantizeMx( + a: torch.Tensor, b: torch.Tensor, *, method: Literal["quest", "abs_max"] = "quest" +) -> tuple[torch.Tensor, torch.Tensor]: + if a.dim() == 0: + raise ValueError("`a` must have at least 1 dimension.") + if a.size(-1) % 32 != 0: + raise ValueError(f"last dim of `a` must be divisible by 32, got {a.size(-1)}.") + if b.device != a.device: + raise ValueError("`a` and `b` must be on the same device.") + + xh_e2m1 = torch.empty( + *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device + ) + + rows, cols = a.numel() // a.size(-1), a.size(-1) // 32 + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + xh_e8m0 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device + ) + + if not hasattr(torch.ops, "_qutlass_C"): + raise RuntimeError( + "The `_qutlass_C` extension is not loaded. " + "Make sure your custom op library is imported before calling fusedQuantizeMx." 
+ ) + + if method == "quest": + return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0) + elif method == "abs_max": + return torch.ops._qutlass_C.fusedQuantizeMxAbsMax(a, b, xh_e2m1, xh_e8m0) + else: + raise ValueError(f"invalid method {method!r}, must be 'quest' or 'abs_max'") + + +if hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"): + + @register_fake("_qutlass_C::fusedQuantizeNv") + def _fake_fused_quantize_nv( + a: torch.Tensor, + b: torch.Tensor, + xh_e2m1: torch.Tensor, + xh_e4m3: torch.Tensor, + global_scale: torch.Tensor, + ): + return xh_e2m1, xh_e4m3 + + +def fusedQuantizeNv( + a: torch.Tensor, b: torch.Tensor, global_scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + xh_e2m1 = torch.empty( + *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device + ) + + rows, cols = a.numel() // a.size(-1), a.size(-1) // 16 + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + xh_e4m3 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=a.device + ) + + return torch.ops._qutlass_C.fusedQuantizeNv(a, b, xh_e2m1, xh_e4m3, global_scale) + + def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor: """ Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832) diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 1f458f940a28..e773e1d13f0b 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch @@ -65,7 +64,7 @@ def paged_attention_v1( context_lens: torch.Tensor, block_size: int, max_context_len: int, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -107,7 +106,7 @@ def paged_attention_v2( context_lens: torch.Tensor, block_size: int, max_context_len: int, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -174,7 +173,7 @@ def varlen_attention( out: torch.Tensor, seqlen_q: torch.Tensor, seqlen_k: torch.Tensor, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, max_seqlen_q: int, max_seqlen_k: int, pdropout: float, @@ -254,8 +253,8 @@ def reshape_and_cache_flash( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: Optional[torch.Tensor] = None, - v_scale: Optional[torch.Tensor] = None, + k_scale: torch.Tensor | None = None, + v_scale: torch.Tensor | None = None, k_scale_float: float = 1.0, v_scale_float: float = 1.0, ) -> None: @@ -283,10 +282,10 @@ def flash_attn_varlen_func( softmax_scale: float, causal: bool, block_table: torch.Tensor, - alibi_slopes: Optional[torch.Tensor], - window_size: Optional[list[int]] = None, - softcap: Optional[float] = 0.0, - cu_seqlens_k: Optional[torch.Tensor] = None, + alibi_slopes: torch.Tensor | None, + window_size: list[int] | None = None, + softcap: float | None = 0.0, + cu_seqlens_k: torch.Tensor | None = None, # The following parameters are not used in ipex kernel currently, # we keep API compatible to CUDA's. scheduler_metadata=None, @@ -295,7 +294,7 @@ def flash_attn_varlen_func( k_descale=None, v_descale=None, num_splits=0, - s_aux: Optional[torch.Tensor] = None, + s_aux: torch.Tensor | None = None, ): if cu_seqlens_k is None: # cu_seqlens_k is not used in ipex kernel. 
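# Editor's note (illustrative sketch, not part of the patch): the shape arithmetic
# used by fusedQuantizeMx above, written out as plain Python. Two e2m1 values are
# packed per uint8, one e8m0 scale covers a 32-element group, and the scale tensor
# is padded to multiples of 128 rows by 4 columns.
def _ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


rows, hidden = 100, 4096                    # example activation shape
packed_last_dim = hidden // 2               # 2048 packed uint8 values per row
scale_cols = hidden // 32                   # 128 scale groups per row
padded_rows = _ceil_div(rows, 128) * 128    # 128
padded_cols = _ceil_div(scale_cols, 4) * 4  # 128
# xh_e2m1 -> (100, 2048) uint8, xh_e8m0 -> (128, 128) float8_e8m0fnu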
@@ -344,10 +343,10 @@ def get_scheduler_metadata( cache_seqlens: torch.Tensor, qkv_dtype=torch.bfloat16, headdim_v=None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - page_size: Optional[int] = None, + cu_seqlens_q: torch.Tensor | None = None, + cu_seqlens_k_new: torch.Tensor | None = None, + cache_leftpad: torch.Tensor | None = None, + page_size: int | None = None, max_seqlen_k_new=0, causal=False, window_size=(-1, -1), # -1 means infinite context window @@ -382,11 +381,11 @@ def swap_blocks( @staticmethod def scaled_fp8_quant( input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, + scale: torch.Tensor | None = None, + num_token_padding: int | None = None, + scale_ub: torch.Tensor | None = None, use_per_token_if_dynamic: bool = False, - output: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -414,7 +413,7 @@ def scaled_fp8_quant( """ # This code assumes batch_dim and num_tokens are flattened assert input.ndim == 2 - shape: Union[tuple[int, int], torch.Size] = input.shape + shape: tuple[int, int] | torch.Size = input.shape out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 61c2dbf55fe3..b527ffcf9b18 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -8,7 +8,7 @@ import numpy.typing as npt -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets diff --git a/vllm/assets/base.py b/vllm/assets/base.py index 409bfc18ff8c..5ca9de4076ad 100644 --- a/vllm/assets/base.py +++ b/vllm/assets/base.py @@ -3,7 +3,6 @@ from functools import lru_cache from pathlib import Path -from typing import Optional import vllm.envs as envs from vllm.connections import global_http_connection @@ -20,9 +19,9 @@ def get_cache_dir() -> Path: @lru_cache -def get_vllm_public_assets(filename: str, s3_prefix: Optional[str] = None) -> Path: +def get_vllm_public_assets(filename: str, s3_prefix: str | None = None) -> Path: """ - Download an asset file from ``s3://vllm-public-assets`` + Download an asset file from `s3://vllm-public-assets` and return the path to the downloaded file. 
""" asset_directory = get_cache_dir() / "vllm_public_assets" diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 6b2ca8f867e0..8818b5997004 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -3,15 +3,14 @@ from dataclasses import dataclass from functools import lru_cache -from typing import Any, ClassVar, Literal, Optional +from typing import Any, ClassVar, Literal -import cv2 import numpy as np import numpy.typing as npt from huggingface_hub import hf_hub_download from PIL import Image -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule from .base import get_cache_dir @@ -43,6 +42,8 @@ def download_video_asset(filename: str) -> str: def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: + import cv2 + cap = cv2.VideoCapture(path) if not cap.isOpened(): raise ValueError(f"Could not open video file {path}") @@ -78,6 +79,8 @@ def video_to_pil_images_list(path: str, num_frames: int = -1) -> list[Image.Imag def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]: + import cv2 + cap = cv2.VideoCapture(path) if not cap.isOpened(): raise ValueError(f"Could not open video file {path}") @@ -137,7 +140,7 @@ def metadata(self) -> dict[str, Any]: ret = video_get_metadata(self.video_path, self.num_frames) return ret - def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: + def get_audio(self, sampling_rate: float | None = None) -> npt.NDArray: """ Read audio data from the video asset, used in Qwen2.5-Omni examples. diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index bb2f36271103..e9c6a278a941 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, Optional, Protocol, TypeVar +from typing import Generic, Protocol, TypeVar import torch +from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey @@ -25,6 +26,13 @@ class AttentionType: """Attention between dec. Q and enc. K/V for encoder-decoder.""" +class MultipleOf: + base: int + + def __init__(self, base: int): + self.base = base + + class AttentionBackend(ABC): """Abstract class for attention backends.""" @@ -33,14 +41,6 @@ class AttentionBackend(ABC): # makes sure the output tensor is allocated inside the cudagraph. accept_output_buffer: bool = False - # Whether this backend supports receiving pre-quantized query input. - # If True, the attention layer will handle query quantization instead - # of the backend, allowing torch.compile to fuse quantization with - # previous operations. 
- # Needs to be worked through for all backends - # https://github.com/vllm-project/vllm/issues/25584 - supports_quant_query_input: bool = False - @staticmethod @abstractmethod def get_name() -> str: @@ -56,6 +56,10 @@ def get_impl_cls() -> type["AttentionImpl"]: def get_metadata_cls() -> type["AttentionMetadata"]: raise NotImplementedError + @classmethod + def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + return cls.get_impl_cls().get_supported_kernel_block_size() + @classmethod def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) @@ -146,16 +150,21 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[list[float]] = None, - sliding_window: Optional[int] = None, + num_kv_heads: int | None = None, + alibi_slopes: list[float] | None = None, + sliding_window: int | None = None, kv_cache_dtype: str = "auto", - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, + kv_sharing_target_layer_name: str | None = None, ) -> None: raise NotImplementedError + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + # TODO: implement this function for all backends. + return [MultipleOf(1)] + @abstractmethod def forward( self, @@ -165,9 +174,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError @@ -182,8 +191,52 @@ def fused_output_quant_supported(self, quant_key: QuantKey): """ return False + def supports_quant_query_input(self) -> bool: + """ + Check if this attention implementation supports pre-quantized query input. + + When True, the attention layer will quantize queries before passing them + to this backend, allowing torch.compile to fuse the quantization with + previous operations. This is typically supported when using FP8 KV cache + with compatible attention kernels (e.g., TRT-LLM). + TODO add support to more backends: + https://github.com/vllm-project/vllm/issues/25584 + + Returns: + bool: True if the implementation can accept pre-quantized queries. 
+ """ + return False + + def process_weights_after_loading(self, act_dtype: torch.dtype): + pass + class MLAAttentionImpl(AttentionImpl[T], Generic[T]): + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + q_lora_rank: int | None, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: ColumnParallelLinear, + indexer: object | None = None, + ) -> None: + raise NotImplementedError + @abstractmethod def forward( self, @@ -193,9 +246,9 @@ def forward( k_pe: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 06f13044d572..05d0159d0861 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -4,12 +4,14 @@ import enum +from vllm.utils.import_utils import resolve_obj_by_qualname + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() TRITON_ATTN = enum.auto() XFORMERS = enum.auto() - ROCM_FLASH = enum.auto() + ROCM_ATTN = enum.auto() ROCM_AITER_MLA = enum.auto() ROCM_AITER_FA = enum.auto() # used for ViT attn backend TORCH_SDPA = enum.auto() @@ -18,11 +20,91 @@ class _Backend(enum.Enum): TRITON_MLA = enum.auto() CUTLASS_MLA = enum.auto() FLASHMLA = enum.auto() + FLASHMLA_SPARSE = enum.auto() FLASH_ATTN_MLA = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() TREE_ATTN = enum.auto() - ROCM_ATTN = enum.auto() ROCM_AITER_UNIFIED_ATTN = enum.auto() + + +BACKEND_MAP = { + _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501 + _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501 + _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501 + _Backend.ROCM_ATTN: "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend", # noqa: E501 + _Backend.ROCM_AITER_MLA: "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend", # noqa: E501 + _Backend.ROCM_AITER_FA: "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend", # noqa: E501 + _Backend.TORCH_SDPA: "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend", # noqa: E501 + _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", # noqa: E501 + _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501 + _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501 + _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501 + _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", # noqa: E501 + _Backend.FLASHMLA_SPARSE: "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend", # noqa: E501 + _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501 + 
_Backend.PALLAS: "vllm.v1.attention.backends.pallas.PallasAttentionBackend", # noqa: E501 + _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501 + _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", # noqa: E501 + _Backend.ROCM_AITER_UNIFIED_ATTN: "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend", # noqa: E501 +} + + +def register_attn_backend(backend: _Backend, class_path: str | None = None): + """ + Decorator: register a custom attention backend into BACKEND_MAPPING. + - If class_path is provided, use it. + - Otherwise, auto-generate from the class object. + Validation: only checks if 'backend' is a valid _Backend enum member. + Overwriting existing mappings is allowed. This enables other hardware + platforms to plug in custom out-of-tree backends. + """ + if not isinstance(backend, _Backend): + raise ValueError(f"{backend} is not a valid _Backend enum value.") + + def decorator(cls): + path = class_path or f"{cls.__module__}.{cls.__qualname__}" + BACKEND_MAP[backend] = path + return cls + + return decorator + + +def backend_to_class_str(backend: _Backend) -> str: + """Get the backend class string + + Args: + backend: The backend enum value + + Returns: + The backend class string + """ + return BACKEND_MAP[backend] + + +def backend_to_class(backend: _Backend) -> type: + """Get the backend class. + + Args: + backend: The backend enum value + + Returns: + The backend class + """ + backend_class_name = backend_to_class_str(backend) + return resolve_obj_by_qualname(backend_class_name) + + +def backend_name_to_enum(backend_name: str) -> _Backend | None: + """ + Convert a string backend name to a _Backend enum value. + + Returns: + _Backend: enum value if backend_name is a valid in-tree type + None: otherwise it's an invalid in-tree type or an out-of-tree platform + is loaded. 
+ """ + assert backend_name is not None + return _Backend[backend_name] if backend_name in _Backend.__members__ else None diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 46a87bdd1f7e..4c7fa477b52b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -3,7 +3,6 @@ """Attention backend utils""" from dataclasses import dataclass -from typing import Optional from vllm.config import ModelConfig from vllm.logger import init_logger @@ -15,7 +14,7 @@ @dataclass class MLADims: - q_lora_rank: Optional[int] + q_lora_rank: int | None kv_lora_rank: int qk_nope_head_dim: int qk_rope_head_dim: int diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 6994debd4589..7544daa3aff7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import Callable, Optional +from collections.abc import Callable +from typing import cast import torch import torch.nn as nn @@ -10,11 +11,13 @@ import vllm.envs as envs from vllm.attention import AttentionType -from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import _Backend -from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl +from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config +from vllm.config.multimodal import MultiModalConfig +from vllm.config.vllm import VllmConfig from vllm.distributed.kv_transfer import ( get_kv_transfer_group, has_kv_transfer_group, @@ -23,21 +26,30 @@ from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform -from vllm.utils import GiB_bytes, direct_register_custom_op +from vllm.utils.torch_utils import ( + direct_register_custom_op, + kv_cache_dtype_str_to_dtype, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheSpec, + MLAAttentionSpec, + SlidingWindowSpec, +) +FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) USE_XFORMERS_OPS = None -try: - tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,) -except AttributeError: - tag_cudagraph_unsafe = () # type: ignore[assignment] def check_xformers_availability(): @@ -81,12 +93,15 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, use_upstream_fa: bool + attn_backend: _Backend, + use_upstream_fa: bool, + attn_backend_override: _Backend | None = None, ) -> tuple[_Backend, Callable]: if ( attn_backend != 
_Backend.FLASH_ATTN and attn_backend != _Backend.ROCM_AITER_FA and check_upstream_fa_availability(torch.get_default_dtype()) + and attn_backend_override is None ): attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True @@ -108,6 +123,69 @@ def maybe_get_vit_flash_attn_backend( return attn_backend, flash_attn_varlen_func +def _init_kv_cache_quant( + layer: nn.Module, + quant_config: QuantizationConfig | None, + prefix: str, + kv_cache_dtype: str, + calculate_kv_scales: bool, +) -> None: + """Initializes KV cache scaling factors and quantization method. + + This helper function sets up the KV cache quantization attributes that are + shared between Attention and MLAAttention layers. It initializes scale + tensors for query, key, value, and probability, and configures the + quantization method if applicable. + + Args: + layer: The attention layer instance to initialize. + quant_config: Optional quantization configuration. + prefix: Layer name prefix for quantization method lookup. + kv_cache_dtype: The KV cache data type string. + calculate_kv_scales: Whether to calculate KV scales dynamically. + """ + # The default k/v_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. + layer.kv_cache_dtype = kv_cache_dtype + layer.calculate_kv_scales = calculate_kv_scales + layer._k_scale = torch.tensor(1.0, dtype=torch.float32) + layer._v_scale = torch.tensor(1.0, dtype=torch.float32) + layer._q_scale = torch.tensor(1.0, dtype=torch.float32) + layer._prob_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep q/k/v_scale on host (cpu) memory for attention + # backends that require the scales to be on host instead of on device. + # e.g. Flashinfer + layer._q_scale_float = 1.0 + layer._k_scale_float = 1.0 + layer._v_scale_float = 1.0 + + # The output scale on host memory. This should be the input scale of + # the quant op after this attention layer. + layer._o_scale_float = None + + quant_method = ( + quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None + ) + if quant_method is not None and not isinstance( + quant_method, UnquantizedLinearMethod + ): + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + layer.quant_method = quant_method + layer.quant_method.create_weights(layer) + + class Attention(nn.Module, AttentionLayerBase): """Attention layer. 
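# Editor's note (illustrative sketch, not part of the patch): the defaults that
# _init_kv_cache_quant above leaves on a layer when no quantization config is given.
# Assumes _init_kv_cache_quant from the hunk above is in scope (it is private to
# vllm/attention/layer.py); _ToyAttn and the prefix string are placeholders.
import torch
import torch.nn as nn


class _ToyAttn(nn.Module):
    pass


layer = _ToyAttn()
_init_kv_cache_quant(
    layer,
    quant_config=None,                          # no quant method is attached
    prefix="model.layers.0.self_attn.attn",     # unused when quant_config is None
    kv_cache_dtype="auto",
    calculate_kv_scales=False,
)
assert layer._q_scale_float == layer._k_scale_float == layer._v_scale_float == 1.0
assert layer._o_scale_float is None  # populated later once an output quant scale is known
assert torch.equal(layer._k_scale, torch.tensor(1.0))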
@@ -125,18 +203,16 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[list[float]] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - logits_soft_cap: Optional[float] = None, - per_layer_sliding_window: Optional[int] = None, - use_mla: bool = False, - use_sparse: bool = False, + num_kv_heads: int | None = None, + alibi_slopes: list[float] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + logits_soft_cap: float | None = None, + per_layer_sliding_window: int | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - attn_backend: Optional[type[AttentionBackend]] = None, + kv_sharing_target_layer_name: str | None = None, + attn_backend: type[AttentionBackend] | None = None, **extra_impl_args, ) -> None: """ @@ -153,6 +229,7 @@ def __init__( else: sliding_window = None + vllm_config = get_current_vllm_config() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size @@ -161,65 +238,26 @@ def __init__( kv_cache_dtype = "auto" block_size = 16 calculate_kv_scales = False + self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype( + kv_cache_dtype, vllm_config.model_config + ) if num_kv_heads is None: num_kv_heads = num_heads assert num_heads % num_kv_heads == 0, ( f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})" ) - # The default k/v_scale is set to 1.0. This is ignored - # when kv-cache is not fp8, and should be used with - # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we - # expect the pre-quantized k/v_scale to be loaded along - # with the model weights. - self.kv_cache_dtype = kv_cache_dtype - self.calculate_kv_scales = calculate_kv_scales - self._k_scale = torch.tensor(1.0, dtype=torch.float32) - self._v_scale = torch.tensor(1.0, dtype=torch.float32) - # FlashAttn doesn't support quantizing the kv-cache only - # but requires q to be quantized as well. - self._q_scale = torch.tensor(1.0, dtype=torch.float32) - self._prob_scale = torch.tensor(1.0, dtype=torch.float32) - - # We also keep q/k/v_scale on host (cpu) memory for attention - # backends that require the scales to be on host instead of on device. - # e.g. Flashinfer - self._q_scale_float = 1.0 - self._k_scale_float = 1.0 - self._v_scale_float = 1.0 - - # The output scale on host memory. This should be the input scale of - # the quant op after this attention layer. - self._o_scale_float: Optional[float] = None - - self.use_mla = use_mla - self.use_sparse = use_sparse + # Initialize KV cache quantization attributes + _init_kv_cache_quant( + self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales + ) + self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None - quant_method = ( - quant_config.get_quant_method(self, prefix=prefix) if quant_config else None - ) - if quant_method is not None and not isinstance( - quant_method, UnquantizedLinearMethod - ): - assert isinstance(quant_method, BaseKVCacheMethod) - # TODO (mgoin): kv cache dtype should be specified in the FP8 - # checkpoint config and become the "auto" behavior - if self.kv_cache_dtype == "fp8_e5m2": - raise ValueError( - "fp8_e5m2 kv-cache is not supported with fp8 checkpoints." 
- ) - # If quantization is enabled, we make "k_scale" and "v_scale" - # parameters so that it can be loaded from the model checkpoint. - # The k/v_scale will then be converted back to native float32 - # values after weight loading. - self.quant_method = quant_method - self.quant_method.create_weights(self) - # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() @@ -229,9 +267,8 @@ def __init__( dtype, kv_cache_dtype, block_size, - use_mla=use_mla, + use_mla=False, has_sink=self.has_sink, - use_sparse=use_sparse, ) else: self.attn_backend = attn_backend @@ -260,7 +297,7 @@ def __init__( self.use_direct_call = not current_platform.opaque_attention_op() self.use_output = self.attn_backend.accept_output_buffer - compilation_config = get_current_vllm_config().compilation_config + compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self @@ -280,36 +317,19 @@ def __init__( # this variable will not be accessed if use_direct_call is True self.kv_cache = [ torch.tensor([]) - for _ in range( - get_current_vllm_config().parallel_config.pipeline_parallel_size - ) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) ] - try: - self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) - self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) - self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) - except torch.cuda.OutOfMemoryError as e: - logger.error("Failed to initialize attention q/k/v range constants: %s", e) - if torch.cuda.is_available(): - logger.debug("CUDA device: %s", torch.cuda.current_device()) - logger.debug( - "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes - ) - logger.debug( - "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes - ) - raise RuntimeError( - "Failed to initialize q/k/v range constants. " - "This may be caused by insufficient memory to allocate " - "kv cache." - ) from e + # Initialize q/k/v range constants. + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) # for attn backends supporting query quantization self.query_quant = None if ( self.kv_cache_dtype.startswith("fp8") - and self.attn_backend.supports_quant_query_input + and self.impl.supports_quant_query_input() ): self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) @@ -321,7 +341,7 @@ def forward( # For some alternate attention backends like MLA the attention output # shape does not match the query shape, so we optionally let the model # definition specify the output tensor shape. 
- output_shape: Optional[torch.Size] = None, + output_shape: torch.Size | None = None, ) -> torch.Tensor: """ The KV cache is stored inside this class and is accessed via @@ -334,7 +354,7 @@ def forward( """ if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) output_dtype = query.dtype if self.query_quant is not None: # quantizing with a simple torch operation enables @@ -343,25 +362,24 @@ def forward( # Otherwise queries are quantized using custom ops # which causes decoding overheads assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} - query, _ = self.query_quant(query, self._q_scale) + + # check if query quantization is supported + if self.impl.supports_quant_query_input(): + query, _ = self.query_quant(query, self._q_scale) if self.use_output: output_shape = output_shape if output_shape is not None else query.shape - output = torch.zeros(output_shape, dtype=output_dtype, device=query.device) + output = torch.empty(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] - # We skip reshaping query, key and value tensors for the MLA - # backend since these tensors have different semantics and are - # processed differently. - if not self.use_mla: - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the - # CPU overheads from the non-CUDA-graph regions. - query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size) - if key is not None: - key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size) + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -410,20 +428,35 @@ def extra_repr(self) -> str: return s def process_weights_after_loading(self, act_dtype: torch.dtype): - if hasattr(self.impl, "process_weights_after_loading"): - self.impl.process_weights_after_loading(act_dtype) - - # FlashInfer requires attention sinks to be float32 - if self.backend == _Backend.FLASHINFER and hasattr(self.impl, "sinks"): - from vllm.v1.attention.backends.flashinfer import FlashInferImpl - - assert isinstance(self.impl, FlashInferImpl) - if self.impl.sinks is not None and self.impl.sinks.dtype != torch.float32: - self.impl.sinks = self.impl.sinks.to(torch.float32) + self.impl.process_weights_after_loading(act_dtype) def get_attn_backend(self) -> type[AttentionBackend]: return self.attn_backend + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Block size may get updated after model loading, refresh it + block_size = vllm_config.cache_config.block_size + # Should not be called for enc-dec or encoder-only attention.
+ assert self.attn_type == AttentionType.DECODER + if self.sliding_window is not None: + assert not vllm_config.model_config.use_mla, ( + "MLA is not supported for slidingwindow" + ) + return SlidingWindowSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_torch_dtype, + sliding_window=self.sliding_window, + ) + else: + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_torch_dtype, + ) + class MultiHeadAttention(nn.Module): """Multi-headed attention without any cache, used for ViT.""" @@ -433,10 +466,11 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, + num_kv_heads: int | None = None, # This has no effect, it is only here to make it easier to swap # between Attention and MultiHeadAttention prefix: str = "", + multimodal_config: MultiModalConfig | None = None, ) -> None: super().__init__() self.num_heads = num_heads @@ -456,7 +490,14 @@ def __init__( dtype = torch.get_default_dtype() # Determine the attention backend - backend = get_vit_attn_backend(head_size=head_size, dtype=dtype) + attn_backend_override = None + if multimodal_config is not None: + attn_backend_override = multimodal_config.mm_encoder_attn_backend + backend = get_vit_attn_backend( + head_size=head_size, + dtype=dtype, + attn_backend_override=attn_backend_override, + ) # Some auto-selected backends can be upgraded # to upstream flash attention if available. @@ -484,6 +525,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) @@ -570,6 +612,218 @@ def forward( return out.reshape(bsz, q_len, -1) +class MLAAttention(nn.Module, AttentionLayerBase): + """Multi-Head Latent Attention layer. + + This class takes query, and compressed key/value tensors as input. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. 
+ """ + + def __init__( + self, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + kv_b_proj: ColumnParallelLinear, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_sparse: bool = False, + indexer: object | None = None, + **extra_impl_args, + ): + super().__init__() + self.num_heads = num_heads + self.scale = scale + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.head_size = kv_lora_rank + qk_rope_head_dim + self.layer_name = prefix + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 16 + calculate_kv_scales = False + + # Initialize KV cache quantization attributes + _init_kv_cache_quant( + self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales + ) + + dtype = torch.get_default_dtype() + self.attn_backend = get_attn_backend( + self.head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=True, + use_sparse=use_sparse, + ) + impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) + self.impl = impl_cls( + num_heads=self.num_heads, + head_size=self.head_size, + scale=self.scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=self.kv_cache_dtype, + logits_soft_cap=None, + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=None, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + kv_b_proj=kv_b_proj, + indexer=indexer, + **extra_impl_args, + ) + + self.use_direct_call = not current_platform.opaque_attention_op() + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + self.kv_cache = [ + torch.tensor([]) + for _ in range( + get_current_vllm_config().parallel_config.pipeline_parallel_size + ) + ] + + self.use_sparse = use_sparse + + # Initialize q/k/v range constants. 
+ self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + + def forward( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output_shape: torch.Size | None = None, + ) -> torch.Tensor: + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + + # Mirror Attention.forward scale calculation path + if self.calculate_kv_scales and getattr( + attn_metadata, "enable_kv_scales_calculation", False + ): + self.calc_kv_scales(q, kv_c_normed, k_pe) + + if self.attn_backend.accept_output_buffer: + output = torch.empty(output_shape, dtype=q.dtype, device=q.device) + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + self_kv_cache, + attn_metadata, + output=output, + ) + return output + else: + return self.impl.forward( + self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata + ) + else: + if self.attn_backend.accept_output_buffer: + output = torch.empty(output_shape, dtype=q.dtype, device=q.device) + torch.ops.vllm.unified_mla_attention_with_output( + q, + kv_c_normed, + k_pe, + output, + self.layer_name, + ) + return output + else: + # We can still access forward context to check calculation flag + if self.calculate_kv_scales: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + if getattr(attn_metadata, "enable_kv_scales_calculation", False): + self.calc_kv_scales(q, kv_c_normed, k_pe) + return torch.ops.vllm.unified_mla_attention( + q, + kv_c_normed, + k_pe, + self.layer_name, + ) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + if hasattr(self.impl, "process_weights_after_loading"): + self.impl.process_weights_after_loading(act_dtype) + + def calc_kv_scales( + self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor + ) -> None: + """Optional scale calculation for MLA inputs. + + Mirrors Attention.calc_kv_scales. 
Not all MLA backends require this + """ + # Use safe defaults if ranges are not present + q_range = getattr(self, "q_range", torch.tensor(1.0)) + k_range = getattr(self, "k_range", torch.tensor(1.0)) + v_range = getattr(self, "v_range", torch.tensor(1.0)) + + self._q_scale.copy_(torch.abs(q).max() / q_range) + # kv_c_normed is the compressed KV representation; use it for k/v + kv_abs_max = torch.abs(kv_c_normed).max() + self._k_scale.copy_(kv_abs_max / k_range) + self._v_scale.copy_(kv_abs_max / v_range) + self._q_scale_float = self._q_scale.item() + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() + self.calculate_kv_scales = False + + def get_attn_backend(self) -> type[AttentionBackend]: + return self.attn_backend + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + kv_cache_dtype = kv_cache_dtype_str_to_dtype( + self.kv_cache_dtype, vllm_config.model_config + ) + return MLAAttentionSpec( + block_size=vllm_config.cache_config.block_size, + num_kv_heads=1, + head_size=self.head_size, + dtype=kv_cache_dtype, + cache_dtype_str=vllm_config.cache_config.cache_dtype, + ) + + def wait_for_kv_layer_from_connector(layer_name: str): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return @@ -672,7 +926,6 @@ def unified_attention_fake( op_name="unified_attention", op_func=unified_attention, fake_impl=unified_attention_fake, - tags=tag_cudagraph_unsafe, ) @@ -682,8 +935,8 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -713,8 +966,8 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, layer_name: str, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> None: return @@ -724,5 +977,94 @@ def unified_attention_with_output_fake( op_func=unified_attention_with_output, mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, - tags=tag_cudagraph_unsafe, +) + + +def unified_mla_attention( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self: MLAAttention = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + + +def unified_mla_attention_fake( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(q).contiguous() + + +direct_register_custom_op( + op_name="unified_mla_attention", + op_func=unified_mla_attention, + mutates_args=[], + fake_impl=unified_mla_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def unified_mla_attention_with_output( + q: torch.Tensor, + 
kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, +) -> None: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self: MLAAttention = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +def unified_mla_attention_with_output_fake( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_mla_attention_with_output", + op_func=unified_mla_attention_with_output, + mutates_args=["output", "output_block_scale"], + fake_impl=unified_mla_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 3d37e901605f..18422404d08f 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import ClassVar, Optional +from typing import ClassVar import torch @@ -9,6 +9,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig +from vllm.config.vllm import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -16,6 +17,7 @@ make_local_attention_virtual_batches, subclass_attention_backend, ) +from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec from ..layer import Attention @@ -60,13 +62,14 @@ def __init__( head_size: int, scale: float, attention_chunk_size: int, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[list[float]] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - kv_sharing_target_layer_name: Optional[str] = None, + num_kv_heads: int | None = None, + alibi_slopes: list[float] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + kv_sharing_target_layer_name: str | None = None, prefix: str = "", ): + self.attention_chunk_size = attention_chunk_size dtype = torch.get_default_dtype() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype @@ -99,3 +102,13 @@ def __init__( kv_sharing_target_layer_name=kv_sharing_target_layer_name, attn_backend=attn_backend, ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + assert self.attention_chunk_size + return ChunkedLocalAttentionSpec( + block_size=vllm_config.cache_config.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_torch_dtype, + 
attention_chunk_size=self.attention_chunk_size, + ) diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py index fb7004f86538..a40a66308a66 100644 --- a/vllm/attention/layers/cross_attention.py +++ b/vllm/attention/layers/cross_attention.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from copy import copy -from typing import Optional import numpy as np import torch @@ -22,7 +21,7 @@ CommonAttentionMetadata, subclass_attention_backend, ) -from vllm.v1.kv_cache_interface import CrossAttentionSpec +from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec logger = init_logger(__name__) @@ -138,8 +137,8 @@ def __init__( num_heads: int, head_size: int, scale: float, - cache_config: Optional[CacheConfig] = None, - attn_type: Optional[str] = None, + cache_config: CacheConfig | None = None, + attn_type: str | None = None, **kwargs, ): dtype = torch.get_default_dtype() @@ -175,3 +174,11 @@ def __init__( attn_type=AttentionType.ENCODER_DECODER, **kwargs, ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + return CrossAttentionSpec( + block_size=vllm_config.cache_config.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + dtype=self.kv_cache_torch_dtype, + ) diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py index f49f195563dc..8d2a046757fe 100644 --- a/vllm/attention/layers/encoder_only_attention.py +++ b/vllm/attention/layers/encoder_only_attention.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from copy import copy -from typing import Optional import torch @@ -15,10 +14,12 @@ from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig +from vllm.config.vllm import VllmConfig from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, subclass_attention_backend, ) +from vllm.v1.kv_cache_interface import KVCacheSpec @functools.lru_cache @@ -60,8 +61,8 @@ def __init__( num_heads: int, head_size: int, scale: float, - cache_config: Optional[CacheConfig] = None, - attn_type: Optional[str] = None, + cache_config: CacheConfig | None = None, + attn_type: str | None = None, **kwargs, ): dtype = torch.get_default_dtype() @@ -99,3 +100,7 @@ def __init__( attn_type=AttentionType.ENCODER_ONLY, **kwargs, ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Does not need KV cache + return None diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 097fbae68cda..b6b7ecd2552a 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -117,14 +117,52 @@ def correct_attn_out( if ctx is None: ctx = CPTritonContext() - lse = torch.empty_like(lses[0]) + # --- Normalize to 3D views --- + if out.ndim == 4 and out.shape[1] == 1: + out = out.squeeze(1) + assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}" + + if lses.ndim == 4 and lses.shape[-1] == 1: + lses = lses.squeeze(-1) + if lses.ndim == 4 and lses.shape[1] == 1: + lses = lses.squeeze(1) + assert lses.ndim == 3, ( + f"expected lses [N,B,H] (optionally with a 1-sized extra dim), " + f"got {tuple(lses.shape)}" + ) + + B, H, D = out.shape + N = lses.shape[0] + + # Strides after we normalized shapes to 3-D views. 
The kernel computes + # offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must + # have the same B/H stride layout as a slice of `lses`. + o_sB, o_sH, o_sD = out.stride() + l_sN, l_sB, l_sH = lses.stride() + + # Allocate LSE with the same B/H strides as `lses` so writes land correctly + # even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze). + lse = torch.empty_strided( + (B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype + ) - grid = (out.shape[0], out.shape[1], 1) - regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank) - const_args = { - "HEAD_DIM": out.shape[-1], - "N_ROUNDED": lses.shape[0], - } + # Kernel launch config + grid = (B, H, 1) + + regular_args = ( + out, + out, + lses, + lse, + o_sB, + o_sH, + o_sD, + l_sN, + l_sB, + l_sH, + cp_rank, + ) + const_args = {"HEAD_DIM": D, "N_ROUNDED": N} ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args) return out, lse @@ -135,6 +173,7 @@ def cp_lse_ag_out_rs( cp_attn_lse: torch.Tensor, cp_group: GroupCoordinator, ctx: CPTritonContext = None, + return_lse=False, ): """ cp_attn_out: [ B, H, D ] @@ -154,8 +193,15 @@ def cp_lse_ag_out_rs( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) - out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + assert out.is_contiguous() out = cp_group.reduce_scatter(out, dim=1) + + if return_lse: + cp_num_heads = lse.shape[1] // cp_group.world_size + cp_rank = cp_group.rank_in_group + lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)] + return out, lse return out diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 0bf354a95b1c..d8ab0b9097ef 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py -from typing import Optional import torch @@ -31,7 +30,7 @@ _flashmla_extension_C_AVAILABLE = False -def _is_flashmla_available() -> tuple[bool, Optional[str]]: +def _is_flashmla_available() -> tuple[bool, str | None]: if not _flashmla_C_AVAILABLE: return ( False, @@ -49,7 +48,7 @@ def _is_flashmla_available() -> tuple[bool, Optional[str]]: return True, None -def is_flashmla_dense_supported() -> tuple[bool, Optional[str]]: +def is_flashmla_dense_supported() -> tuple[bool, str | None]: """ Return: is_supported_flag, unsupported_reason (optional). """ @@ -61,7 +60,7 @@ def is_flashmla_dense_supported() -> tuple[bool, Optional[str]]: return True, None -def is_flashmla_sparse_supported() -> tuple[bool, Optional[str]]: +def is_flashmla_sparse_supported() -> tuple[bool, str | None]: """ Return: is_supported_flag, unsupported_reason (optional). """ @@ -80,9 +79,9 @@ def get_mla_metadata( cache_seqlens: torch.Tensor, num_q_tokens_per_head_k: int, num_heads_k: int, - num_heads_q: Optional[int] = None, + num_heads_q: int | None = None, is_fp8_kvcache: bool = False, - topk: Optional[int] = None, + topk: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -103,6 +102,12 @@ def get_mla_metadata( (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32. 
""" + if is_fp8_kvcache and topk is None: + return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8( + cache_seqlens, + num_q_tokens_per_head_k, + num_heads_k, + ) return torch.ops._flashmla_C.get_mla_decoding_metadata( cache_seqlens, num_q_tokens_per_head_k, @@ -121,12 +126,12 @@ def flash_mla_with_kvcache( head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, - softmax_scale: Optional[float] = None, + softmax_scale: float | None = None, causal: bool = False, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, + descale_q: torch.Tensor | None = None, + descale_k: torch.Tensor | None = None, is_fp8_kvcache: bool = False, - indices: Optional[torch.Tensor] = None, + indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: diff --git a/vllm/attention/ops/merge_attn_states.py b/vllm/attention/ops/merge_attn_states.py index 79800eb40766..16106f3c93a6 100644 --- a/vllm/attention/ops/merge_attn_states.py +++ b/vllm/attention/ops/merge_attn_states.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -13,7 +12,7 @@ def merge_attn_states( prefix_lse: torch.Tensor, suffix_output: torch.Tensor, suffix_lse: torch.Tensor, - output_lse: Optional[torch.Tensor] = None, + output_lse: torch.Tensor | None = None, ) -> None: # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel # is not support for FP8 dtype, fallback to use Triton kernel. diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 4db7d1a3a325..8e010ffba32e 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import torch @@ -27,7 +26,7 @@ class PagedAttentionMetadata: # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. - seq_lens_tensor: Optional[torch.Tensor] + seq_lens_tensor: torch.Tensor | None # Maximum sequence length in the batch. 0 if it is prefill-only batch. max_decode_seq_len: int # (batch_size, max_blocks_per_seq). @@ -36,7 +35,7 @@ class PagedAttentionMetadata: # in the kv cache. Each block can contain up to block_size tokens. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. 
- block_tables: Optional[torch.Tensor] + block_tables: torch.Tensor | None class PagedAttention: @@ -102,7 +101,7 @@ def forward_decode( kv_cache_dtype: str, num_kv_heads: int, scale: float, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, k_scale: torch.Tensor, v_scale: torch.Tensor, tp_rank: int = 0, @@ -211,8 +210,8 @@ def forward_prefix( query_start_loc: torch.Tensor, seq_lens_tensor: torch.Tensor, max_query_len: int, - alibi_slopes: Optional[torch.Tensor], - sliding_window: Optional[int], + alibi_slopes: torch.Tensor | None, + sliding_window: int | None, k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index c358b5971f86..6308f63cc4e7 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer def get_aiter_mla_metadata( @@ -30,9 +29,9 @@ def aiter_mla_decode_fwd( sm_scale: float, qo_indptr: torch.Tensor, max_seqlen_qo: int, - kv_indptr: Optional[torch.Tensor] = None, - kv_indices: Optional[torch.Tensor] = None, - kv_last_page_lens: Optional[torch.Tensor] = None, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, logit_cap: float = 0.0, ): torch.ops.vllm.rocm_aiter_mla_decode_fwd( @@ -55,9 +54,9 @@ def mla_decode_fwd_impl( o: torch.Tensor, qo_indptr: torch.Tensor, max_seqlen_qo: int, - kv_indptr: Optional[torch.Tensor] = None, - kv_indices: Optional[torch.Tensor] = None, - kv_last_page_lens: Optional[torch.Tensor] = None, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, sm_scale: float = 1.0, logit_cap: float = 0.0, ) -> None: @@ -83,9 +82,9 @@ def mla_decode_fwd_fake( o: torch.Tensor, qo_indptr: torch.Tensor, max_seqlen_qo: int, - kv_indptr: Optional[torch.Tensor] = None, - kv_indices: Optional[torch.Tensor] = None, - kv_last_page_lens: Optional[torch.Tensor] = None, + kv_indptr: torch.Tensor | None = None, + kv_indices: torch.Tensor | None = None, + kv_last_page_lens: torch.Tensor | None = None, sm_scale: float = 1.0, logit_cap: float = 0.0, ) -> None: diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py index 069cfcaf00aa..5c1ce68dde1b 100644 --- a/vllm/attention/ops/rocm_aiter_paged_attn.py +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import aiter as rocm_aiter import torch @@ -62,7 +61,7 @@ def forward_decode( kv_cache_dtype: str, num_kv_heads: int, scale: float, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, k_scale: torch.Tensor, v_scale: torch.Tensor, tp_rank: int = 0, diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index d29f92f8cecb..3c87a24afd9c 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -1,6 +1,5 @@ # 
SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -15,7 +14,7 @@ def merge_attn_states( prefix_lse: torch.Tensor, suffix_output: torch.Tensor, suffix_lse: torch.Tensor, - output_lse: Optional[torch.Tensor] = None, + output_lse: torch.Tensor | None = None, ) -> None: num_tokens = output.shape[0] num_query_heads = output.shape[1] diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 3a5bbb997286..9890d8d80cba 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -6,34 +6,20 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import cache -from typing import Optional, Union import torch import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname +from vllm.utils import STR_BACKEND_ENV_VAR +from vllm.utils.import_utils import resolve_obj_by_qualname logger = init_logger(__name__) -def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: - """ - Convert a string backend name to a _Backend enum value. - - Returns: - * _Backend: enum value if backend_name is a valid in-tree type - * None: otherwise it's an invalid in-tree type or an out-of-tree platform is - loaded. - """ - assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else None - - -def get_env_variable_attn_backend() -> Optional[_Backend]: +def get_env_variable_attn_backend() -> _Backend | None: """ Get the backend override specified by the vLLM attention backend environment variable, if one is specified. @@ -54,10 +40,10 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: # # THIS SELECTION TAKES PRECEDENCE OVER THE # VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE -forced_attn_backend: Optional[_Backend] = None +forced_attn_backend: _Backend | None = None -def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: +def global_force_attn_backend(attn_backend: _Backend | None) -> None: """ Force all attention operations to use a specified backend. @@ -72,7 +58,7 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: forced_attn_backend = attn_backend -def get_global_forced_attn_backend() -> Optional[_Backend]: +def get_global_forced_attn_backend() -> _Backend | None: """ Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. 
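The selector helpers above document a precedence order: a globally forced backend overrides the `VLLM_ATTENTION_BACKEND` environment variable, which in turn overrides auto-selection. A minimal sketch of that precedence, assuming the signatures shown in this hunk; the `_Backend.FLASH_ATTN` member and the import paths are used purely for illustration:

```python
# Illustration only: resolve a backend override with the precedence the
# selector documents (forced backend first, then the environment variable).
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import (
    get_env_variable_attn_backend,
    get_global_forced_attn_backend,
    global_force_attn_backend,
)


def resolve_backend_override() -> _Backend | None:
    # A backend forced via global_force_attn_backend() wins over the
    # VLLM_ATTENTION_BACKEND environment variable.
    forced = get_global_forced_attn_backend()
    if forced is not None:
        return forced
    return get_env_variable_attn_backend()


# Example: force FlashAttention for the rest of the process.
global_force_attn_backend(_Backend.FLASH_ATTN)
assert resolve_backend_override() == _Backend.FLASH_ATTN
```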
@@ -91,7 +77,7 @@ def __bool__(self) -> bool: def is_attn_backend_supported( - attn_backend: Union[str, type[AttentionBackend]], + attn_backend: str | type[AttentionBackend], head_size: int, dtype: torch.dtype, *, @@ -141,7 +127,7 @@ def is_attn_backend_supported( def get_attn_backend( head_size: int, dtype: torch.dtype, - kv_cache_dtype: Optional[str], + kv_cache_dtype: str | None, block_size: int, use_mla: bool = False, has_sink: bool = False, @@ -168,7 +154,7 @@ def get_attn_backend( def _cached_get_attn_backend( head_size: int, dtype: torch.dtype, - kv_cache_dtype: Optional[str], + kv_cache_dtype: str | None, block_size: int, use_v1: bool = False, use_mla: bool = False, @@ -181,12 +167,12 @@ def _cached_get_attn_backend( # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. selected_backend = None - backend_by_global_setting: Optional[_Backend] = get_global_forced_attn_backend() + backend_by_global_setting: _Backend | None = get_global_forced_attn_backend() if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: # Check the environment variable and override if specified - backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + backend_by_env_var: str | None = envs.VLLM_ATTENTION_BACKEND if backend_by_env_var is not None: if backend_by_env_var.endswith("_VLLM_V1"): logger.warning( @@ -205,6 +191,8 @@ def _cached_get_attn_backend( ) # get device-specific attn_backend + from vllm.platforms import current_platform + attention_cls = current_platform.get_attn_backend_cls( selected_backend, head_size, diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index e13afd46ee96..b92b822c1d19 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from vllm import envs from vllm.logger import init_logger @@ -21,7 +20,7 @@ get_scheduler_metadata = ops.get_scheduler_metadata -def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: +def get_flash_attn_version(requires_alibi: bool = False) -> int | None: # import here to avoid circular dependencies from vllm.platforms import current_platform diff --git a/vllm/beam_search.py b/vllm/beam_search.py index e0ba863b9210..fcd2d1f0e01a 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional from vllm.logprobs import Logprob from vllm.lora.request import LoRARequest @@ -22,13 +22,13 @@ class BeamSearchSequence: # The tokens include the prompt. 
tokens: list[int] logprobs: list[dict[int, Logprob]] - lora_request: Optional[LoRARequest] = None + lora_request: LoRARequest | None = None cum_logprob: float = 0.0 - text: Optional[str] = None - finish_reason: Optional[str] = None - stop_reason: Union[int, str, None] = None + text: str | None = None + finish_reason: str | None = None + stop_reason: int | str | None = None multi_modal_data: Optional["MultiModalDataDict"] = None - mm_processor_kwargs: Optional[dict[str, Any]] = None + mm_processor_kwargs: dict[str, Any] | None = None @dataclass @@ -45,8 +45,8 @@ class BeamSearchInstance: def __init__( self, prompt_tokens: list[int], - lora_request: Optional[LoRARequest] = None, - logprobs: Optional[list[dict[int, Logprob]]] = None, + lora_request: LoRARequest | None = None, + logprobs: list[dict[int, Logprob]] | None = None, **kwargs, ): self.beams: list[BeamSearchSequence] = [ diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 7ffc21905924..eb8cd64c34ba 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -21,13 +21,13 @@ import math import random from abc import ABC, abstractmethod -from collections.abc import Iterator, Mapping +from collections.abc import Callable, Iterator, Mapping from contextlib import suppress from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO -from typing import Any, Callable, Optional, Union, cast +from typing import Any, cast import numpy as np from PIL import Image @@ -39,7 +39,7 @@ from vllm.multimodal import MultiModalDataDict from vllm.multimodal.image import convert_image_mode from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule try: from datasets import load_dataset @@ -75,12 +75,12 @@ class SampleRequest: Represents a single inference request for benchmarking. """ - prompt: Union[str, list[str]] + prompt: str | list[str] prompt_len: int expected_output_len: int - multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None - lora_request: Optional[LoRARequest] = None - request_id: Optional[str] = None + multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None + lora_request: LoRARequest | None = None + request_id: str | None = None # ----------------------------------------------------------------------------- @@ -94,7 +94,7 @@ class BenchmarkDataset(ABC): def __init__( self, - dataset_path: Optional[str] = None, + dataset_path: str | None = None, random_seed: int = DEFAULT_SEED, disable_shuffle: bool = False, **kwargs, @@ -119,7 +119,7 @@ def __init__( def apply_multimodal_chat_transformation( self, prompt: str, - mm_content: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None, + mm_content: MultiModalDataDict | dict | list[dict] | None = None, ) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. @@ -154,9 +154,9 @@ def load_data(self) -> None: def get_random_lora_request( self, - max_loras: Optional[int] = None, - lora_path: Optional[str] = None, - ) -> Optional[LoRARequest]: + max_loras: int | None = None, + lora_path: str | None = None, + ) -> LoRARequest | None: """ Optionally select a random LoRA request. 
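Most of the churn in these files is the same mechanical change: `typing.Optional[X]` and `typing.Union[...]` annotations are rewritten with the PEP 604 `X | None` / `X | Y` syntax. A small stand-alone example (not vLLM code) showing that the two spellings describe the same annotation on Python 3.10+:

```python
# Stand-alone illustration of the annotation rewrite applied throughout
# this diff (PEP 604). Requires Python 3.10+ when the new syntax is used
# outside `from __future__ import annotations`.
from dataclasses import dataclass
from typing import Optional, get_type_hints


@dataclass
class Example:
    old_style: Optional[int] = None  # spelling removed by this diff
    new_style: int | None = None     # spelling used after this diff


hints = get_type_hints(Example)
# The two annotations are equivalent; on 3.10+ they compare equal.
assert hints["old_style"] == hints["new_style"]
```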
@@ -384,7 +384,7 @@ def gen_prompt_decode_to_target_len( target_token_len: int, max_retry: int = 10, add_special_tokens: bool = False, - rng: Optional[np.random.Generator] = None, + rng: np.random.Generator | None = None, ) -> tuple[str, list[int]]: """ Ensure decoded-then-encoded prompt length matches the target token length. @@ -478,6 +478,22 @@ def sample( batchsize: int = 1, **kwargs, ) -> list[SampleRequest]: + # validate total input tokens (prefix + sampled) is at least 1. + num_special = int(tokenizer.num_special_tokens_to_add()) + real_input_len = max(0, int(input_len) - num_special) + min_sampled_input = math.floor(real_input_len * (1.0 - float(range_ratio))) + min_total_input = int(prefix_len) + min_sampled_input + if min_total_input < 1: + raise ValueError( + "--random-input-len is too small: with tokenizer special " + f"tokens {num_special} and --random-range-ratio {range_ratio}, " + "the minimum possible total input tokens (prefix + sampled) is " + f"{min_total_input}. Increase --random-input-len and/or " + "--random-prefix-len, or decrease --random-range-ratio so that " + "prefix_len + floor(max(0, random_input_len - num_special)) " + "* (1 - range_ratio) >= 1." + ) + input_lens, output_lens, offsets = self.get_sampling_params( num_requests, range_ratio, input_len, output_len, tokenizer ) @@ -572,6 +588,7 @@ def get_sampling_params( # Ensure the lower bound for output length is at least 1 to # prevent sampling 0 tokens. output_low = max(output_low, 1) + output_high = max(output_high, 1) if input_low > input_high: raise ValueError( @@ -638,6 +655,112 @@ def generate_token_sequence( return prompt, total_input_len, token_mismatch +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDatasetForReranking(RandomDataset): + """ + Random dataset specialized for the needs of scoring: + - Batches of inputs + - Inputs composed of pairs + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + batchsize: int = 1, + is_reranker: bool = True, + **kwargs, + ) -> list[SampleRequest]: + n_sep_tokens = int(is_reranker) + + query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len + + query_lens, _, query_offsets = self.get_sampling_params( + 1, range_ratio, query_len_param, 0, tokenizer + ) + + query_len = int(query_lens[0]) + + if not is_reranker: + assert num_requests > 1 and batchsize > 1 + num_requests -= 1 + batchsize -= 1 + doc_len_param = input_len + else: + doc_len_param = input_len - query_len - n_sep_tokens + + doc_lens, _, doc_offsets = self.get_sampling_params( + num_requests, range_ratio, doc_len_param, 0, tokenizer + ) + vocab_size = tokenizer.vocab_size + + query_prompt, query_input_len, token_mismatch_total = ( + self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=[], + prefix_len=0, + vocab_size=vocab_size, + input_len=query_len, + offset=int(query_offsets[0]), + index=0, + ) + ) + + requests = [] + for i in range(num_requests): + prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501 + tokenizer=tokenizer, + prefix_token_ids=[], + prefix_len=0, + vocab_size=vocab_size, + 
input_len=int(doc_lens[i]), + offset=int(doc_offsets[i]), + index=i + 1, + ) + token_mismatch_total += token_mismatch + requests.append((prompt, total_input_len)) + + batch_requests = [] + # Create batched requests + for i in range(0, num_requests, batchsize): + batch = requests[i : i + batchsize] + query_contrib = ( + (query_input_len + n_sep_tokens) * len(batch) + if is_reranker + else query_input_len + ) + batch_requests.append( + SampleRequest( + prompt=[query_prompt] + [req[0] for req in batch], + prompt_len=query_contrib + sum(req[1] for req in batch), + expected_output_len=0, + request_id=request_id_prefix + str(i // batchsize), + ) + ) + + if token_mismatch_total != 0: + logger.warning( + "Across all generated prompts, there were %d %s tokens " + "than expected after decoding and re-encoding. This is " + "expected due to the imperfect nature of the sampling " + "procedure.", + abs(token_mismatch_total), + "more" if token_mismatch_total > 0 else "fewer", + ) + + return batch_requests + + # ----------------------------------------------------------------------------- # MultiModalDataset Implementation # ----------------------------------------------------------------------------- @@ -1054,9 +1177,9 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - lora_path: Optional[str] = None, - max_loras: Optional[int] = None, - output_len: Optional[int] = None, + lora_path: str | None = None, + max_loras: int | None = None, + output_len: int | None = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, @@ -1149,6 +1272,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "sonnet", "random", "random-mm", + "random-rerank", "hf", "custom", "prefix_repetition", @@ -1292,6 +1416,14 @@ def add_dataset_parser(parser: FlexibleArgumentParser): default=1, help=("Batch size for random sampling. Only used for embeddings benchmark."), ) + random_group.add_argument( + "--no-reranker", + action="store_true", + help=( + "Whether the model supports reranking natively." + " Only used for reranker benchmark." + ), + ) # random multimodal dataset options random_mm_group = parser.add_argument_group( @@ -1584,7 +1716,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: if dataset_class.IS_MULTIMODAL and not ( args.backend in ("openai-chat", "openai-audio") - or "openai-embeddings-" in args.backend + or "embeddings-" in args.backend ): # multi-modal benchmark is only available on OpenAI Chat # endpoint-type. 
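The guard added to `RandomDataset.sample` above rejects configurations whose smallest possible request would contain zero input tokens: after subtracting the tokenizer's special tokens, the low end of the sampled range plus the prefix must be at least one token. A short worked example of that bound; the helper below is illustrative and not part of the diff:

```python
import math


def min_total_input_tokens(
    input_len: int,
    prefix_len: int,
    range_ratio: float,
    num_special_tokens: int,
) -> int:
    # Mirrors the bound checked in RandomDataset.sample:
    # prefix_len + floor(max(0, input_len - num_special) * (1 - range_ratio))
    real_input_len = max(0, input_len - num_special_tokens)
    min_sampled = math.floor(real_input_len * (1.0 - range_ratio))
    return prefix_len + min_sampled


# With --random-input-len 10, two special tokens, and
# --random-range-ratio 0.9, the low end of the range is zero tokens,
# so the check raises unless a prefix contributes at least one token.
assert min_total_input_tokens(10, 0, 0.9, 2) == 0  # rejected
assert min_total_input_tokens(10, 1, 0.9, 2) == 1  # accepted
```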
@@ -1678,6 +1810,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: request_id_prefix=args.request_id_prefix, no_oversample=args.no_oversample, ), + "random-rerank": lambda: RandomDatasetForReranking( + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + input_len=args.random_input_len, + range_ratio=args.random_range_ratio, + request_id_prefix=args.request_id_prefix, + batchsize=args.random_batch_size, + is_reranker=not args.no_reranker, + ), "prefix_repetition": lambda: PrefixRepetitionRandomDataset( random_seed=args.seed, dataset_path=args.dataset_path, @@ -1766,9 +1911,9 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - lora_path: Optional[str] = None, - max_loras: Optional[int] = None, - output_len: Optional[int] = None, + lora_path: str | None = None, + max_loras: int | None = None, + output_len: int | None = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, request_id_prefix: str = "", @@ -1997,8 +2142,8 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - max_loras: Optional[int] = None, - lora_path: Optional[str] = None, + max_loras: int | None = None, + lora_path: str | None = None, request_id_prefix: str = "", no_oversample: bool = False, **kwargs, @@ -2034,15 +2179,15 @@ def sample( class HuggingFaceDataset(BenchmarkDataset): """Base class for datasets hosted on HuggingFace.""" - SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set() + SUPPORTED_DATASET_PATHS: set[str] | dict[str, Callable] = set() def __init__( self, dataset_path: str, dataset_split: str, no_stream: bool = False, - dataset_subset: Optional[str] = None, - hf_name: Optional[str] = None, + dataset_subset: str | None = None, + hf_name: str | None = None, **kwargs, ) -> None: super().__init__(dataset_path=dataset_path, **kwargs) @@ -2083,7 +2228,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, @@ -2152,7 +2297,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, @@ -2206,7 +2351,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, @@ -2267,7 +2412,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, request_id_prefix: str = "", @@ -2331,7 +2476,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, enable_multimodal_chat: bool = False, skip_chat_template: bool = False, request_id_prefix: str = "", @@ -2397,7 +2542,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, skip_chat_template: bool = False, request_id_prefix: str = "", no_oversample: bool = False, @@ -2478,7 +2623,7 @@ def 
sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, request_id_prefix: str = "", no_oversample: bool = False, **kwargs, @@ -2660,7 +2805,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, request_id_prefix: str = "", no_oversample: bool = False, **kwargs, @@ -2738,7 +2883,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, request_id_prefix: str = "", no_oversample: bool = False, **kwargs, @@ -2850,13 +2995,14 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]: requests = [] token_mismatch_total = 0 for _ in range(num_prefixes): - prefix_tokens = _generate_exact_length_tokens(prefix_len) + prefix_tokens, prefix_mismatch = _generate_exact_length_tokens(prefix_len) + token_mismatch_total += prefix_mismatch for _ in range(prompts_per_prefix): - suffix_tokens, token_mistmatch = _generate_exact_length_tokens( + suffix_tokens, suffix_mismatch = _generate_exact_length_tokens( suffix_len ) - token_mismatch_total += token_mistmatch + token_mismatch_total += suffix_mismatch combined_tokens = prefix_tokens + suffix_tokens prompt = tokenizer.decode(combined_tokens) prompt_len = len(combined_tokens) @@ -2902,7 +3048,7 @@ def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - output_len: Optional[int] = None, + output_len: int | None = None, enable_multimodal_chat: bool = False, request_id_prefix: str = "", no_oversample: bool = False, diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index 7692697fe768..b4f1751837f4 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -7,7 +7,7 @@ import json import os import time -from typing import Any, Optional +from typing import Any import numpy as np from tqdm import tqdm @@ -127,7 +127,7 @@ def llm_generate(): ), ) - def run_to_completion(profile_dir: Optional[str] = None): + def run_to_completion(profile_dir: str | None = None): if profile_dir: llm.start_profile() llm_generate() diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 34dce5edb0c7..ed0fdec25186 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -10,7 +10,7 @@ import traceback from collections.abc import Awaitable from dataclasses import dataclass, field -from typing import Any, Literal, Optional, Protocol, Union +from typing import Any, Literal, Protocol import aiohttp import regex as re @@ -64,19 +64,19 @@ def add_chunk(self, chunk_bytes: bytes) -> list[str]: class RequestFuncInput: """The input for the request function.""" - prompt: str + prompt: str | list[str] api_url: str prompt_len: int output_len: int model: str - model_name: Optional[str] = None - logprobs: Optional[int] = None - extra_headers: Optional[dict] = None - extra_body: Optional[dict] = None - multi_modal_content: Optional[Union[dict, list[dict]]] = None + model_name: str | None = None + logprobs: int | None = None + extra_headers: dict | None = None + extra_body: dict | None = None + multi_modal_content: dict | list[dict] | None = None ignore_eos: bool = False - language: Optional[str] = None - request_id: Optional[str] = None + language: str | None = None + request_id: str | None = None @dataclass @@ -100,14 +100,14 @@ def __call__( self, request_func_input: 
RequestFuncInput, session: aiohttp.ClientSession, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> Awaitable[RequestFuncOutput]: ... def _validate_api_url( api_url: str, api_name: str, - expected_suffixes: Union[str, set[str]], + expected_suffixes: str | set[str], ) -> None: if isinstance(expected_suffixes, str): expected_suffixes = {expected_suffixes} @@ -141,7 +141,7 @@ def _update_headers_common( async def async_request_openai_completions( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: """The async request function for the OpenAI Completions API. @@ -279,7 +279,7 @@ def _get_chat_content( async def async_request_openai_chat_completions( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, mm_position: Literal["first", "last"] = "last", ) -> RequestFuncOutput: api_url = request_func_input.api_url @@ -376,7 +376,7 @@ async def async_request_openai_chat_completions( async def async_request_openai_audio( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: # Lazy import without PlaceholderModule to avoid vllm dep. import soundfile @@ -484,12 +484,12 @@ def to_bytes(y, sr): return output -async def _run_openai_embeddings( +async def _run_pooling_request( session: aiohttp.ClientSession, api_url: str, payload: dict[str, Any], headers: dict[str, Any], - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: output = RequestFuncOutput() st = time.perf_counter() @@ -497,11 +497,18 @@ async def _run_openai_embeddings( try: async with session.post(url=api_url, headers=headers, json=payload) as response: if response.status == 200: - output.latency = time.perf_counter() - st - data = await response.json() + output.ttft = output.latency = time.perf_counter() - st + + if payload.get("encoding_format", "float") == "bytes": + metadata = json.loads(response.headers["metadata"]) + usage = metadata.get("usage", {}) + else: + data = await response.json() + usage = data.get("usage", {}) + output.success = True output.generated_text = "" - output.prompt_len = data.get("usage", {}).get("prompt_tokens", 0) + output.prompt_len = usage.get("prompt_tokens", 0) else: output.success = False output.error = response.reason or "" @@ -517,7 +524,7 @@ async def _run_openai_embeddings( async def async_request_openai_embeddings( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings") @@ -527,6 +534,9 @@ async def async_request_openai_embeddings( if request_func_input.model_name else request_func_input.model, "input": request_func_input.prompt, + # Many embedding models have short context length, + # this is to avoid dropping some of the requests. 
+ "truncate_prompt_tokens": -1, } _update_payload_common(payload, request_func_input) @@ -536,7 +546,7 @@ async def async_request_openai_embeddings( } _update_headers_common(headers, request_func_input) - return await _run_openai_embeddings( + return await _run_pooling_request( session, api_url, payload=payload, @@ -545,26 +555,29 @@ async def async_request_openai_embeddings( ) -async def async_request_openai_embeddings_chat( +async def async_request_vllm_rerank( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, - pbar: Optional[tqdm] = None, - mm_position: Literal["first", "last"] = "last", + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings") + _validate_api_url(api_url, "vLLM score API", "rerank") - content = _get_chat_content(request_func_input, mm_position=mm_position) + assert ( + isinstance(request_func_input.prompt, list) + and len(request_func_input.prompt) > 1 + ) payload = { "model": request_func_input.model_name if request_func_input.model_name else request_func_input.model, - "messages": [ - {"role": "user", "content": content}, - ], + "query": request_func_input.prompt[0], + "documents": request_func_input.prompt[1:], + # Many reranker models have short context length, + # this is to avoid dropping some of the requests. + "truncate_prompt_tokens": -1, } - _update_payload_common(payload, request_func_input) headers = { "Content-Type": "application/json", @@ -572,7 +585,7 @@ async def async_request_openai_embeddings_chat( } _update_headers_common(headers, request_func_input) - return await _run_openai_embeddings( + return await _run_pooling_request( session, api_url, payload=payload, @@ -581,25 +594,41 @@ async def async_request_openai_embeddings_chat( ) -async def async_request_openai_embeddings_clip( +async def async_request_openai_embeddings_chat( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, + mm_position: Literal["first", "last"] = "last", ) -> RequestFuncOutput: - if request_func_input.multi_modal_content: - # Image input - request_func_input.prompt = "" + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings") - # max_model_len=77 is too short for most datasets, - # so by default we truncate the prompt to max_model_len - if request_func_input.extra_body is None: - request_func_input.extra_body = {} - if "truncate_prompt_tokens" not in request_func_input.extra_body: - request_func_input.extra_body["truncate_prompt_tokens"] = -1 + content = _get_chat_content(request_func_input, mm_position=mm_position) - return await async_request_openai_embeddings_chat( - request_func_input, + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "messages": [ + {"role": "user", "content": content}, + ], + # Many embedding models have short context length, + # this is to avoid dropping some of the requests. 
+ "truncate_prompt_tokens": -1, + } + _update_payload_common(payload, request_func_input) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + _update_headers_common(headers, request_func_input) + + return await _run_pooling_request( session, + api_url, + payload=payload, + headers=headers, pbar=pbar, ) @@ -616,11 +645,13 @@ def _try_extract_request_idx(request_func_input: RequestFuncInput): return None -async def async_request_openai_embeddings_vlm2vec( - request_func_input: RequestFuncInput, - session: aiohttp.ClientSession, - pbar: Optional[tqdm] = None, -) -> RequestFuncOutput: +def _preprocess_clip(request_func_input: RequestFuncInput): + if request_func_input.multi_modal_content: + # Image input + request_func_input.prompt = "" + + +def _preprocess_vlm2vec(request_func_input: RequestFuncInput): if request_func_input.multi_modal_content: request_idx = _try_extract_request_idx(request_func_input) @@ -637,6 +668,28 @@ async def async_request_openai_embeddings_vlm2vec( f"{request_func_input.prompt}" ) + +async def async_request_openai_embeddings_clip( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + _preprocess_clip(request_func_input) + + return await async_request_openai_embeddings_chat( + request_func_input, + session, + pbar=pbar, + ) + + +async def async_request_openai_embeddings_vlm2vec( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + _preprocess_vlm2vec(request_func_input) + return await async_request_openai_embeddings_chat( request_func_input, session, @@ -645,6 +698,61 @@ async def async_request_openai_embeddings_vlm2vec( ) +async def async_request_infinity_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "Infinity Embeddings API", "embeddings") + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + } + + if request_func_input.prompt: + payload["input"] = request_func_input.prompt + else: + mm_content = request_func_input.multi_modal_content + assert isinstance(mm_content, dict) + + mm_type = mm_content["type"] + payload["input"] = mm_content[mm_type]["url"] + payload["modality"] = mm_type.split("_", 1)[0] + + _update_payload_common(payload, request_func_input) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + _update_headers_common(headers, request_func_input) + + return await _run_pooling_request( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +async def async_request_infinity_embeddings_clip( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + _preprocess_clip(request_func_input) + + return await async_request_infinity_embeddings( + request_func_input, + session, + pbar=pbar, + ) + + # TODO: Add more request functions for different API protocols. 
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = { "vllm": async_request_openai_completions, @@ -655,6 +763,11 @@ async def async_request_openai_embeddings_vlm2vec( "openai-embeddings-chat": async_request_openai_embeddings_chat, "openai-embeddings-clip": async_request_openai_embeddings_clip, "openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec, + # Infinity embedding server: https://github.com/michaelfeil/infinity + "infinity-embeddings": async_request_infinity_embeddings, + "infinity-embeddings-clip": async_request_infinity_embeddings_clip, + # (Infinity embedding server does not support vlm2vec) + "vllm-rerank": async_request_vllm_rerank, } OPENAI_COMPATIBLE_BACKENDS = [ diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index f061c1479968..71d136d61cea 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -18,6 +18,7 @@ import argparse import asyncio +import contextlib import gc import importlib.util import json @@ -30,7 +31,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Literal, Optional +from typing import Any, Literal import aiohttp import numpy as np @@ -57,12 +58,13 @@ class TaskType(Enum): GENERATION = "generation" - EMBEDDING = "embedding" + POOLING = "pooling" @dataclass class BenchmarkMetrics: completed: int + failed: int total_input: int total_output: int request_throughput: float @@ -96,6 +98,7 @@ class BenchmarkMetrics: @dataclass class EmbedBenchmarkMetrics: completed: int + failed: int total_input: int request_throughput: float total_token_throughput: float @@ -106,9 +109,9 @@ class EmbedBenchmarkMetrics: def _get_current_request_rate( - ramp_up_strategy: Optional[Literal["linear", "exponential"]], - ramp_up_start_rps: Optional[int], - ramp_up_end_rps: Optional[int], + ramp_up_strategy: Literal["linear", "exponential"] | None, + ramp_up_start_rps: int | None, + ramp_up_end_rps: int | None, request_index: int, total_requests: int, request_rate: float, @@ -134,9 +137,9 @@ async def get_request( input_requests: list[SampleRequest], request_rate: float, burstiness: float = 1.0, - ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, - ramp_up_start_rps: Optional[int] = None, - ramp_up_end_rps: Optional[int] = None, + ramp_up_strategy: Literal["linear", "exponential"] | None = None, + ramp_up_start_rps: int | None = None, + ramp_up_end_rps: int | None = None, ) -> AsyncGenerator[tuple[SampleRequest, float], None]: """ Asynchronously generates requests at a specified rate @@ -238,12 +241,15 @@ def calculate_metrics_for_embeddings( """ total_input = 0 completed = 0 + failed = 0 e2els: list[float] = [] for i in range(len(outputs)): if outputs[i].success: e2els.append(outputs[i].latency) completed += 1 total_input += outputs[i].prompt_len + else: + failed += 1 if completed == 0: warnings.warn( @@ -253,6 +259,7 @@ def calculate_metrics_for_embeddings( ) metrics = EmbedBenchmarkMetrics( completed=completed, + failed=failed, total_input=total_input, request_throughput=completed / dur_s, total_token_throughput=total_input / dur_s, @@ -365,6 +372,7 @@ def calculate_metrics( # Find the time range across all successful requests successful_outputs = [output for output in outputs if output.success] + failed_outputs = [output for output in outputs if not output.success] if successful_outputs: min_start_time = min(output.start_time for output in successful_outputs) max_end_time = max( @@ -426,6 +434,7 @@ def calculate_metrics( metrics = BenchmarkMetrics( 
completed=completed, + failed=len(failed_outputs), total_input=total_input, total_output=sum(actual_output_lens), request_throughput=completed / dur_s, @@ -473,22 +482,23 @@ async def benchmark( model_name: str, tokenizer: PreTrainedTokenizerBase, input_requests: list[SampleRequest], - logprobs: Optional[int], + logprobs: int | None, request_rate: float, burstiness: float, disable_tqdm: bool, + num_warmups: int, profile: bool, selected_percentile_metrics: list[str], selected_percentiles: list[float], ignore_eos: bool, goodput_config_dict: dict[str, float], - max_concurrency: Optional[int], - lora_modules: Optional[Iterable[str]], - extra_headers: Optional[dict], - extra_body: Optional[dict], - ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, - ramp_up_start_rps: Optional[int] = None, - ramp_up_end_rps: Optional[int] = None, + max_concurrency: int | None, + lora_modules: Iterable[str] | None, + extra_headers: dict | None, + extra_body: dict | None, + ramp_up_strategy: Literal["linear", "exponential"] | None = None, + ramp_up_start_rps: int | None = None, + ramp_up_end_rps: int | None = None, ready_check_timeout_sec: int = 600, ): try: @@ -558,10 +568,37 @@ async def benchmark( f"Error: {test_output.error}" ) else: - print("Initial test run completed. Starting main benchmark run...") + print("Initial test run completed.") else: print("Skipping endpoint ready check.") + if num_warmups > 0: + print(f"Warming up with {num_warmups} requests...") + warmup_pbar = None if disable_tqdm else tqdm(total=num_warmups) + warmup_semaphore = ( + asyncio.Semaphore(max_concurrency) + if max_concurrency + else contextlib.nullcontext() + ) + warmup_tasks = [] + + async def warmup_limited_request_func(): + async with warmup_semaphore: + return await request_func( + request_func_input=test_input, session=session, pbar=warmup_pbar + ) + + for _ in range(num_warmups): + request_task = asyncio.create_task(warmup_limited_request_func()) + warmup_tasks.append(request_task) + _ = await asyncio.gather(*warmup_tasks) + + if warmup_pbar is not None: + warmup_pbar.close() + print("Warmup run completed.") + + print("Starting main benchmark run...") + if lora_modules: # For each input request, choose a LoRA module at random. lora_modules = iter( @@ -605,17 +642,13 @@ async def benchmark( pbar = None if disable_tqdm else tqdm(total=len(input_requests)) - # This can be used once the minimum Python version is 3.10 or higher, - # and it will simplify the code in limited_request_func. 
- # semaphore = (asyncio.Semaphore(max_concurrency) - # if max_concurrency else contextlib.nullcontext()) - semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None + semaphore = ( + asyncio.Semaphore(max_concurrency) + if max_concurrency + else contextlib.nullcontext() + ) async def limited_request_func(request_func_input, session, pbar): - if semaphore is None: - return await request_func( - request_func_input=request_func_input, session=session, pbar=pbar - ) async with semaphore: return await request_func( request_func_input=request_func_input, session=session, pbar=pbar @@ -709,6 +742,7 @@ async def limited_request_func(request_func_input, session, pbar): print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10}".format("Failed requests:", metrics.failed)) if max_concurrency is not None: print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency)) if request_rate != float("inf"): @@ -754,6 +788,7 @@ async def limited_request_func(request_func_input, session, pbar): result = { "duration": benchmark_duration, "completed": metrics.completed, + "failed": metrics.failed, "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, @@ -1032,6 +1067,12 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Specify to disable tqdm progress bar.", ) + parser.add_argument( + "--num-warmups", + type=int, + default=0, + help="Number of warmup requests.", + ) parser.add_argument( "--profile", action="store_true", @@ -1087,10 +1128,12 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--percentile-metrics", type=str, - default="ttft,tpot,itl", + default=None, help="Comma-separated list of selected metrics to report percentiles. " "This argument specifies the metrics to report percentiles. " - 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ', + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'If not specified, defaults to "ttft,tpot,itl" for generative models ' + 'and "e2el" for pooling models.', ) parser.add_argument( "--metric-percentiles", @@ -1188,7 +1231,7 @@ def add_cli_args(parser: argparse.ArgumentParser): default=None, help="The model name used in the API. " "If not specified, the model name will be the " - "same as the ``--model`` argument. ", + "same as the `--model` argument. ", ) parser.add_argument( @@ -1233,6 +1276,15 @@ def add_cli_args(parser: argparse.ArgumentParser): "the ready check will be skipped.", ) + parser.add_argument( + "--extra-body", + help="A JSON string representing extra body parameters to include " + "in each request. " + 'Example: \'{"chat_template_kwargs":{"enable_thinking":false}}\'', + type=json.loads, + default=None, + ) + def main(args: argparse.Namespace) -> dict[str, Any]: return asyncio.run(main_async(args)) @@ -1304,7 +1356,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: goodput_config_dict = check_goodput_args(args) backend = args.backend - task_type = TaskType.EMBEDDING if "embeddings" in backend else TaskType.GENERATION + task_type = ( + TaskType.POOLING + if "embeddings" in backend or "rerank" in backend + else TaskType.GENERATION + ) # Collect the sampling parameters. 
if task_type == TaskType.GENERATION: @@ -1330,8 +1386,16 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if "temperature" not in sampling_params: sampling_params["temperature"] = 0.0 # Default to greedy decoding. + + default_percentile_metrics = "ttft,tpot,itl" else: sampling_params = {} + default_percentile_metrics = "e2el" + + extra_body = args.extra_body or {} + extra_body = {**sampling_params, **extra_body} + + percentile_metrics: str = args.percentile_metrics or default_percentile_metrics # Avoid GC processing "static" data - reduce pause times. gc.collect() @@ -1350,15 +1414,16 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: request_rate=args.request_rate, burstiness=args.burstiness, disable_tqdm=args.disable_tqdm, + num_warmups=args.num_warmups, profile=args.profile, - selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentile_metrics=percentile_metrics.split(","), selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, extra_headers=headers, - extra_body=sampling_params, + extra_body=extra_body, ramp_up_strategy=args.ramp_up_strategy, ramp_up_start_rps=args.ramp_up_start_rps, ramp_up_end_rps=args.ramp_up_end_rps, diff --git a/vllm/benchmarks/sweep/__init__.py b/vllm/benchmarks/sweep/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/benchmarks/sweep/param_sweep.py b/vllm/benchmarks/sweep/param_sweep.py new file mode 100644 index 000000000000..986561ed8502 --- /dev/null +++ b/vllm/benchmarks/sweep/param_sweep.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os +from typing import Any + + +class ParameterSweep(list["ParameterSweepItem"]): + @classmethod + def read_json(cls, filepath: os.PathLike): + with open(filepath, "rb") as f: + records = json.load(f) + + return cls.from_records(records) + + @classmethod + def from_records(cls, records: list[dict[str, object]]): + if not isinstance(records, list): + raise TypeError( + f"The parameter sweep should be a list of dictionaries, " + f"but found type: {type(records)}" + ) + + return cls(ParameterSweepItem.from_record(record) for record in records) + + +class ParameterSweepItem(dict[str, object]): + @classmethod + def from_record(cls, record: dict[str, object]): + if not isinstance(record, dict): + raise TypeError( + f"Each item in the parameter sweep should be a dictionary, " + f"but found type: {type(record)}" + ) + + return cls(record) + + def __or__(self, other: dict[str, Any]): + return type(self)(super().__or__(other)) + + # In JSON, we prefer "_" + def _iter_param_key_candidates(self, param_key: str): + # Inner config arguments are not converted by the CLI + if "." in param_key: + prefix, rest = param_key.split(".", 1) + for prefix_candidate in self._iter_param_key_candidates(prefix): + yield prefix_candidate + "." 
+ rest + + return + + yield param_key + yield param_key.replace("-", "_") + yield param_key.replace("_", "-") + + # In CLI, we prefer "-" + def _iter_cmd_key_candidates(self, param_key: str): + for k in reversed(tuple(self._iter_param_key_candidates(param_key))): + yield "--" + k + + def _normalize_cmd_key(self, param_key: str): + return next(self._iter_cmd_key_candidates(param_key)) + + def has_param(self, param_key: str) -> bool: + return any(k in self for k in self._iter_param_key_candidates(param_key)) + + def apply_to_cmd(self, cmd: list[str]) -> list[str]: + cmd = list(cmd) + + for k, v in self.items(): + for k_candidate in self._iter_cmd_key_candidates(k): + try: + k_idx = cmd.index(k_candidate) + + if isinstance(v, bool): + cmd[k_idx] = self._normalize_cmd_key(k if v else "no-" + k) + else: + cmd[k_idx + 1] = str(v) + + break + except ValueError: + continue + else: + if isinstance(v, bool): + cmd.append(self._normalize_cmd_key(k if v else "no-" + k)) + else: + cmd.extend([self._normalize_cmd_key(k), str(v)]) + + return cmd + + def as_text(self, sep: str = ", ") -> str: + return sep.join(f"{k}={v}" for k, v in self.items()) diff --git a/vllm/benchmarks/sweep/plot.py b/vllm/benchmarks/sweep/plot.py new file mode 100644 index 000000000000..92485c09b416 --- /dev/null +++ b/vllm/benchmarks/sweep/plot.py @@ -0,0 +1,530 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import json +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass +from functools import partial +from pathlib import Path +from types import TracebackType + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from typing_extensions import Self, override + +from vllm.utils.collection_utils import full_groupby + +from .utils import sanitize_filename + + +@dataclass +class PlotFilterBase(ABC): + var: str + target: str + + @classmethod + def parse_str(cls, s: str): + for op_key in PLOT_FILTERS: + if op_key in s: + key, value = s.split(op_key) + return PLOT_FILTERS[op_key]( + key, + value.removeprefix(op_key).strip("'").strip('"'), + ) + else: + raise ValueError( + f"Invalid operator for plot filter '{s}'. " + f"Valid operators are: {sorted(PLOT_FILTERS)}", + ) + + @abstractmethod + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + """Applies this filter to a DataFrame.""" + raise NotImplementedError + + +@dataclass +class PlotEqualTo(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + try: + target = float(self.target) + except ValueError: + target = self.target + + return df[df[self.var] == target] + + +@dataclass +class PlotLessThan(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + return df[df[self.var] < float(self.target)] + + +@dataclass +class PlotLessThanOrEqualTo(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + return df[df[self.var] <= float(self.target)] + + +@dataclass +class PlotGreaterThan(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + return df[df[self.var] > float(self.target)] + + +@dataclass +class PlotGreaterThanOrEqualTo(PlotFilterBase): + @override + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + return df[df[self.var] >= float(self.target)] + + +# NOTE: The ordering is important! 
Match longer op_keys first +PLOT_FILTERS: dict[str, type[PlotFilterBase]] = { + "==": PlotEqualTo, + "<=": PlotLessThanOrEqualTo, + ">=": PlotGreaterThanOrEqualTo, + "<": PlotLessThan, + ">": PlotGreaterThan, +} + + +class PlotFilters(list[PlotFilterBase]): + @classmethod + def parse_str(cls, s: str): + if not s: + return cls() + + return cls(PlotFilterBase.parse_str(e) for e in s.split(",")) + + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + for item in self: + df = item.apply(df) + + return df + + +@dataclass +class PlotBinner: + var: str + bin_size: float + + @classmethod + def parse_str(cls, s: str): + for op_key in PLOT_BINNERS: + if op_key in s: + key, value = s.split(op_key) + return PLOT_BINNERS[op_key](key, float(value.removeprefix(op_key))) + else: + raise ValueError( + f"Invalid operator for plot binner '{s}'. " + f"Valid operators are: {sorted(PLOT_BINNERS)}", + ) + + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + """Applies this binner to a DataFrame.""" + df = df.copy() + df[self.var] = df[self.var] // self.bin_size * self.bin_size + return df + + +PLOT_BINNERS: dict[str, type[PlotBinner]] = { + "%": PlotBinner, +} + + +class PlotBinners(list[PlotBinner]): + @classmethod + def parse_str(cls, s: str): + if not s: + return cls() + + return cls(PlotBinner.parse_str(e) for e in s.split(",")) + + def apply(self, df: pd.DataFrame) -> pd.DataFrame: + for item in self: + df = item.apply(df) + + return df + + +def _json_load_bytes(path: Path) -> list[dict[str, object]]: + with path.open("rb") as f: + return json.load(f) + + +def _get_metric(run_data: dict[str, object], metric_key: str): + try: + return run_data[metric_key] + except KeyError as exc: + raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc + + +def _get_group(run_data: dict[str, object], group_keys: list[str]): + return tuple((k, str(_get_metric(run_data, k))) for k in group_keys) + + +def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...]): + parts = list[str]() + if group: + parts.extend(("FIGURE-", *(f"{k}={v}" for k, v in group))) + else: + parts.append("figure") + + return fig_dir / sanitize_filename("-".join(parts) + ".png") + + +class DummyExecutor: + map = map + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ) -> None: + return None + + +def _plot_fig( + fig_dir: Path, + fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]], + row_by: list[str], + col_by: list[str], + curve_by: list[str], + *, + var_x: str, + var_y: str, + filter_by: PlotFilters, + bin_by: PlotBinners, + scale_x: str | None, + scale_y: str | None, + dry_run: bool, +): + fig_group, fig_data = fig_group_data + + row_groups = full_groupby( + fig_data, + key=lambda item: _get_group(item, row_by), + ) + num_rows = len(row_groups) + num_cols = max( + len(full_groupby(row_data, key=lambda item: _get_group(item, col_by))) + for _, row_data in row_groups + ) + + fig_path = _get_fig_path(fig_dir, fig_group) + + print("[BEGIN FIGURE]") + print(f"Group: {dict(fig_group)}") + print(f"Grid: {num_rows} rows x {num_cols} cols") + print(f"Output file: {fig_path}") + + if dry_run: + print("[END FIGURE]") + return + + df = pd.DataFrame.from_records(fig_data) + + if var_x not in df.columns: + raise ValueError( + f"Cannot find {var_x=!r} in parameter sweep results. 
" + f"Available variables: {df.columns.tolist()}" + ) + if var_y not in df.columns: + raise ValueError( + f"Cannot find {var_y=!r} in parameter sweep results. " + f"Available variables: {df.columns.tolist()}" + ) + for k in row_by: + if k not in df.columns: + raise ValueError( + f"Cannot find row_by={k!r} in parameter sweep results. " + f"Available variables: {df.columns.tolist()}" + ) + for k in col_by: + if k not in df.columns: + raise ValueError( + f"Cannot find col_by={k!r} in parameter sweep results. " + f"Available variables: {df.columns.tolist()}" + ) + for k in curve_by: + if k not in df.columns: + raise ValueError( + f"Cannot find curve_by={k!r} in parameter sweep results. " + f"Available variables: {df.columns.tolist()}" + ) + + df = filter_by.apply(df) + df = bin_by.apply(df) + + df["row_group"] = ( + pd.concat( + [k + "=" + df[k].astype(str) for k in row_by], + axis=1, + ).agg("\n".join, axis=1) + if row_by + else "(All)" + ) + + df["col_group"] = ( + pd.concat( + [k + "=" + df[k].astype(str) for k in col_by], + axis=1, + ).agg("\n".join, axis=1) + if col_by + else "(All)" + ) + + g = sns.FacetGrid(df, row="row_group", col="col_group") + + if row_by and col_by: + g.set_titles("{row_name}\n{col_name}") + elif row_by: + g.set_titles("{row_name}") + elif col_by: + g.set_titles("{col_name}") + else: + g.set_titles("") + + if scale_x: + g.set(xscale=scale_x) + if scale_y: + g.set(yscale=scale_y) + + if len(curve_by) <= 3: + hue, style, size, *_ = (*curve_by, None, None, None) + + g.map_dataframe( + sns.lineplot, + x=var_x, + y=var_y, + hue=hue, + style=style, + size=size, + markers=True, + ) + + g.add_legend(title=hue) + else: + df["curve_group"] = ( + pd.concat( + [k + "=" + df[k].astype(str) for k in curve_by], + axis=1, + ).agg("\n".join, axis=1) + if curve_by + else "(All)" + ) + + g.map_dataframe( + sns.lineplot, + x=var_x, + y=var_y, + hue="curve_group", + markers=True, + ) + + g.add_legend() + + g.savefig(fig_path) + plt.close(g.figure) + + print("[END FIGURE]") + + +def plot( + output_dir: Path, + fig_dir: Path, + fig_by: list[str], + row_by: list[str], + col_by: list[str], + curve_by: list[str], + *, + var_x: str, + var_y: str, + filter_by: PlotFilters, + bin_by: PlotBinners, + scale_x: str | None, + scale_y: str | None, + dry_run: bool, +): + all_data = [ + run_data + for path in output_dir.rglob("**/summary.json") + for run_data in _json_load_bytes(path) + ] + + if not all_data: + raise ValueError(f"Did not find any parameter sweep results under {output_dir}") + + fig_dir.mkdir(parents=True, exist_ok=True) + + fig_groups = full_groupby( + all_data, + key=lambda item: _get_group(item, fig_by), + ) + + with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor: + # Resolve the iterable to ensure that the workers are run + all( + executor.map( + partial( + _plot_fig, + fig_dir, + row_by=row_by, + col_by=col_by, + curve_by=curve_by, + var_x=var_x, + var_y=var_y, + filter_by=filter_by, + bin_by=bin_by, + scale_x=scale_x, + scale_y=scale_y, + dry_run=dry_run, + ), + fig_groups, + ) + ) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument( + "OUTPUT_DIR", + type=str, + default="results", + help="The directory containing the results to plot, " + "i.e., the `--output-dir` argument to the parameter sweep script.", + ) + parser.add_argument( + "--fig-dir", + type=str, + default="", + help="The directory to save the figures, relative to `OUTPUT_DIR`. 
" + "By default, the same directory is used.", + ) + parser.add_argument( + "--fig-by", + type=str, + default="", + help="A comma-separated list of variables, such that a separate figure " + "is created for each combination of these variables.", + ) + parser.add_argument( + "--row-by", + type=str, + default="", + help="A comma-separated list of variables, such that a separate row " + "is created for each combination of these variables.", + ) + parser.add_argument( + "--col-by", + type=str, + default="", + help="A comma-separated list of variables, such that a separate column " + "is created for each combination of these variables.", + ) + parser.add_argument( + "--curve-by", + type=str, + default=None, + help="A comma-separated list of variables, such that a separate curve " + "is created for each combination of these variables.", + ) + parser.add_argument( + "--var-x", + type=str, + default="request_throughput", + help="The variable for the x-axis.", + ) + parser.add_argument( + "--var-y", + type=str, + default="p99_e2el_ms", + help="The variable for the y-axis", + ) + parser.add_argument( + "--filter-by", + type=str, + default="", + help="A comma-separated list of statements indicating values to filter by. " + "This is useful to remove outliers. " + "Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means " + "plot only the points where `max_concurrency` is less than 1000 and " + "`max_num_batched_tokens` is no greater than 4096.", + ) + parser.add_argument( + "--bin-by", + type=str, + default="", + help="A comma-separated list of statements indicating values to bin by. " + "This is useful to avoid plotting points that are too close together. " + "Example: `request_throughput%1` means " + "use a bin size of 1 for the `request_throughput` variable.", + ) + parser.add_argument( + "--scale-x", + type=str, + default=None, + help="The scale to use for the x-axis. " + "Currently only accepts string values such as 'log' and 'sqrt'. " + "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html", + ) + parser.add_argument( + "--scale-y", + type=str, + default=None, + help="The scale to use for the y-axis. " + "Currently only accepts string values such as 'log' and 'sqrt'. " + "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="If set, prints the information about each figure to plot, " + "then exits without drawing them.", + ) + + +def main(args: argparse.Namespace): + output_dir = Path(args.OUTPUT_DIR) + if not output_dir.exists(): + raise ValueError(f"No parameter sweep results under {output_dir}") + + curve_by = [] if not args.curve_by else args.curve_by.split(",") + row_by = [] if not args.row_by else args.row_by.split(",") + col_by = [] if not args.col_by else args.col_by.split(",") + fig_by = [] if not args.fig_by else args.fig_by.split(",") + + plot( + output_dir=output_dir, + fig_dir=output_dir / args.fig_dir, + fig_by=fig_by, + row_by=row_by, + col_by=col_by, + curve_by=curve_by, + var_x=args.var_x, + var_y=args.var_y, + filter_by=PlotFilters.parse_str(args.filter_by), + bin_by=PlotBinners.parse_str(args.bin_by), + scale_x=args.scale_x, + scale_y=args.scale_y, + dry_run=args.dry_run, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Plot performance curves from parameter sweep results." 
+ ) + add_cli_args(parser) + + main(parser.parse_args()) diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py new file mode 100644 index 000000000000..a06d4d6d6098 --- /dev/null +++ b/vllm/benchmarks/sweep/serve.py @@ -0,0 +1,409 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import contextlib +import json +import shlex +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path + +import pandas as pd + +from .param_sweep import ParameterSweep, ParameterSweepItem +from .server import ServerProcess +from .utils import sanitize_filename + + +@contextlib.contextmanager +def run_server( + serve_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + serve_overrides: ParameterSweepItem, + dry_run: bool, +): + server_cmd = serve_overrides.apply_to_cmd(serve_cmd) + + print("[BEGIN SERVER]") + print(f"Server overrides: {serve_overrides}") + print(f"Server command: {server_cmd}") + + if dry_run: + yield None + print("[END SERVER]") + return + + with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server: + yield server + + print("[END SERVER]") + + +def _update_run_data( + run_data: dict[str, object], + serve_overrides: ParameterSweepItem, + bench_overrides: ParameterSweepItem, + run_number: int, +): + run_data["run_number"] = run_number + run_data.update(serve_overrides) + run_data.update(bench_overrides) + + return run_data + + +def run_benchmark( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_overrides: ParameterSweepItem, + bench_overrides: ParameterSweepItem, + run_number: int, + output_path: Path, + dry_run: bool, +): + benchmark_cmd = [ + *bench_overrides.apply_to_cmd(bench_cmd), + "--percentile-metrics", + "ttft,tpot,itl,e2el", + "--save-result", + "--result-dir", + str(output_path.parent), + "--result-filename", + output_path.name, + ] + + print("[BEGIN BENCHMARK]") + print(f"Benchmark overrides: {bench_overrides}") + print(f"Run Number: {run_number}") + print(f"Benchmark command: {benchmark_cmd}") + print(f"Output file: {output_path}") + + run_data: dict[str, object] + + if output_path.exists(): + print("Found existing results. 
Skipping.") + + with output_path.open("rb") as f: + run_data = json.load(f) + return _update_run_data( + run_data, + serve_overrides, + bench_overrides, + run_number, + ) + + if server is None: + if not dry_run: + raise ValueError(f"Cannot find results at {output_path}") + + print("[END BENCHMARK]") + return None + + output_path.parent.mkdir(parents=True, exist_ok=True) + + server.run_subcommand(benchmark_cmd) + server.after_bench() + + with output_path.open("rb") as f: + run_data = json.load(f) + + run_data = _update_run_data( + run_data, + serve_overrides, + bench_overrides, + run_number, + ) + + with output_path.open("w") as f: + json.dump(run_data, f, indent=4) + + print("[END BENCHMARK]") + + return run_data + + +def _get_comb_base_path( + output_dir: Path, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, +): + parts = list[str]() + if serve_comb: + parts.extend(("SERVE-", serve_comb.as_text(sep="-"))) + if bench_comb: + parts.extend(("BENCH-", bench_comb.as_text(sep="-"))) + + return output_dir / sanitize_filename("-".join(parts)) + + +def _get_comb_run_path(base_path: Path, run_number: int | None): + if run_number is None: + return base_path / "summary.json" + + return base_path / f"run={run_number}.json" + + +def _comb_needs_server( + serve_comb: ParameterSweepItem, + bench_combs: ParameterSweep, + output_dir: Path, +): + for bench_comb in bench_combs: + base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb) + if not _get_comb_run_path(base_path, run_number=None).exists(): + return True + + return False + + +def run_comb( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + base_path: Path, + num_runs: int, + dry_run: bool, +): + comb_data = list[dict[str, object]]() + + for run_number in range(num_runs): + run_data = run_benchmark( + server, + bench_cmd, + serve_overrides=serve_comb, + bench_overrides=bench_comb, + run_number=run_number, + output_path=_get_comb_run_path(base_path, run_number), + dry_run=dry_run, + ) + + if run_data is not None: + comb_data.append(run_data) + + if dry_run: + return None + + with _get_comb_run_path(base_path, run_number=None).open("w") as f: + json.dump(comb_data, f, indent=4) + + return comb_data + + +def run_combs( + serve_cmd: list[str], + bench_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + serve_params: ParameterSweep, + bench_params: ParameterSweep, + output_dir: Path, + num_runs: int, + dry_run: bool, +): + all_data = list[dict[str, object]]() + for serve_comb in serve_params: + with ( + run_server( + serve_cmd, + after_bench_cmd, + show_stdout=show_stdout, + serve_overrides=serve_comb, + dry_run=dry_run, + ) + if _comb_needs_server(serve_comb, bench_params, output_dir) + else contextlib.nullcontext() + ) as server: + for bench_comb in bench_params: + base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb) + + comb_data = run_comb( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + base_path=base_path, + num_runs=num_runs, + dry_run=dry_run, + ) + + if comb_data is not None: + all_data.extend(comb_data) + + if dry_run: + return None + + combined_df = pd.DataFrame.from_records(all_data) + combined_df.to_csv(output_dir / "summary.csv") + + return combined_df + + +@dataclass +class SweepServeArgs: + serve_cmd: list[str] + bench_cmd: list[str] + after_bench_cmd: list[str] + show_stdout: bool + serve_params: ParameterSweep + bench_params: ParameterSweep + output_dir: Path + 
num_runs: int + dry_run: bool + resume: str | None + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + serve_cmd = shlex.split(args.serve_cmd) + bench_cmd = shlex.split(args.bench_cmd) + after_bench_cmd = ( + [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd) + ) + + if args.serve_params: + serve_params = ParameterSweep.read_json(args.serve_params) + else: + # i.e.: run serve_cmd without any modification + serve_params = ParameterSweep.from_records([{}]) + + if args.bench_params: + bench_params = ParameterSweep.read_json(args.bench_params) + else: + # i.e.: run bench_cmd without any modification + bench_params = ParameterSweep.from_records([{}]) + + num_runs = args.num_runs + if num_runs < 1: + raise ValueError("`num_runs` should be at least 1.") + + return cls( + serve_cmd=serve_cmd, + bench_cmd=bench_cmd, + after_bench_cmd=after_bench_cmd, + show_stdout=args.show_stdout, + serve_params=serve_params, + bench_params=bench_params, + output_dir=Path(args.output_dir), + num_runs=num_runs, + dry_run=args.dry_run, + resume=args.resume, + ) + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser.add_argument( + "--serve-cmd", + type=str, + required=True, + help="The command used to run the server: `vllm serve ...`", + ) + parser.add_argument( + "--bench-cmd", + type=str, + required=True, + help="The command used to run the benchmark: `vllm bench serve ...`", + ) + parser.add_argument( + "--after-bench-cmd", + type=str, + default=None, + help="After a benchmark run is complete, invoke this command instead of " + "the default `ServerProcess.reset_caches()`.", + ) + parser.add_argument( + "--show-stdout", + action="store_true", + help="If set, logs the standard output of subcommands. " + "Useful for debugging but can be quite spammy.", + ) + parser.add_argument( + "--serve-params", + type=str, + default=None, + help="Path to JSON file containing a list of parameter combinations " + "for the `vllm serve` command. " + "If both `serve_params` and `bench_params` are given, " + "this script will iterate over their Cartesian product.", + ) + parser.add_argument( + "--bench-params", + type=str, + default=None, + help="Path to JSON file containing a list of parameter combinations " + "for the `vllm bench serve` command. 
" + "If both `serve_params` and `bench_params` are given, " + "this script will iterate over their Cartesian product.", + ) + parser.add_argument( + "-o", + "--output-dir", + type=str, + default="results", + help="The directory to which results are written.", + ) + parser.add_argument( + "--num-runs", + type=int, + default=3, + help="Number of runs per parameter combination.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="If set, prints the commands to run, " + "then exits without executing them.", + ) + parser.add_argument( + "--resume", + type=str, + default=None, + help="Set this to the name of a directory under `output_dir` (which is a " + "timestamp) to resume a previous execution of this script, i.e., only run " + "parameter combinations for which there are still no output files.", + ) + + return parser + + +def run_main(args: SweepServeArgs): + timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = args.output_dir / timestamp + + if args.resume and not output_dir.exists(): + raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") + + try: + return run_combs( + serve_cmd=args.serve_cmd, + bench_cmd=args.bench_cmd, + after_bench_cmd=args.after_bench_cmd, + show_stdout=args.show_stdout, + serve_params=args.serve_params, + bench_params=args.bench_params, + output_dir=output_dir, + num_runs=args.num_runs, + dry_run=args.dry_run, + ) + except BaseException as exc: + raise RuntimeError( + f"The script was terminated early. Use `--resume {timestamp}` " + f"to continue the script from its last checkpoint." + ) from exc + + +def main(args: argparse.Namespace): + run_main(SweepServeArgs.from_cli_args(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run vLLM server benchmark under multiple settings." 
+ ) + SweepServeArgs.add_cli_args(parser) + + main(parser.parse_args()) diff --git a/vllm/benchmarks/sweep/serve_sla.py b/vllm/benchmarks/sweep/serve_sla.py new file mode 100644 index 000000000000..6159aba4bbb5 --- /dev/null +++ b/vllm/benchmarks/sweep/serve_sla.py @@ -0,0 +1,484 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import contextlib +import json +import math +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Literal, get_args + +import pandas as pd +from typing_extensions import assert_never + +from .param_sweep import ParameterSweep, ParameterSweepItem +from .serve import SweepServeArgs, run_benchmark, run_server +from .server import ServerProcess +from .sla_sweep import SLASweep, SLASweepItem +from .utils import sanitize_filename + + +def _get_sla_base_path( + output_dir: Path, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, +): + parts = list[str]() + if serve_comb: + parts.extend(("SERVE-", serve_comb.as_text(sep="-"))) + if bench_comb: + parts.extend(("BENCH-", bench_comb.as_text(sep="-"))) + + return output_dir / sanitize_filename("-".join(parts)) + + +def _get_sla_iter_path( + base_path: Path, + sla_comb: SLASweepItem, + sla_variable: str, + sla_value: int | None, +): + if sla_value is None: + prefix = sla_comb.as_text(sep="-") + return base_path / f"SLA--{prefix}.json" + + return base_path / f"{sla_variable}={sla_value}" + + +def _get_sla_run_path(iter_path: Path, run_number: int | None): + if run_number is None: + return iter_path / "summary.json" + + return iter_path / f"run={run_number}.json" + + +def _sla_needs_server( + serve_comb: ParameterSweepItem, + bench_combs: ParameterSweep, + sla_combs: SLASweep, + sla_variable: str, + output_dir: Path, +): + for bench_comb in bench_combs: + base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb) + for sla_comb in sla_combs: + if not _get_sla_iter_path( + base_path, + sla_comb, + sla_variable, + sla_value=None, + ).exists(): + return True + + return False + + +def run_sla( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + iter_path: Path, + num_runs: int, + dry_run: bool, +): + iter_data = list[dict[str, object]]() + + for run_number in range(num_runs): + run_data = run_benchmark( + server, + bench_cmd, + serve_overrides=serve_comb, + bench_overrides=bench_comb, + run_number=run_number, + output_path=_get_sla_run_path(iter_path, run_number), + dry_run=dry_run, + ) + + if run_data is not None: + iter_data.append(run_data) + + if dry_run: + return None + + with _get_sla_run_path(iter_path, run_number=None).open("w") as f: + json.dump(iter_data, f, indent=4) + + return iter_data + + +SLAVariable = Literal["request_rate", "max_concurrency"] + + +def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable): + request_throughput = float(run_data["request_throughput"]) # type: ignore + if sla_variable == "request_rate": + return request_throughput + if sla_variable == "max_concurrency": + mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore + return request_throughput * mean_latency_ms / 1000 + + assert_never(sla_variable) + + +def _estimate_sla_bounds( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + sla_comb: SLASweepItem, + base_path: Path, + num_runs: int, + 
dry_run: bool, + sla_variable: SLAVariable, + init_value: int, + max_value: int, +): + sla_data = list[dict[str, object]]() + + max_passing: int = 0 + min_failing: int = 0 + + val: int = init_value + assert val > 0 + + while True: + print(f"Testing {sla_variable}: {val} req/s") + + iter_data = run_sla( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb | {sla_variable: val}, + iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val), + num_runs=num_runs, + dry_run=dry_run, + ) + + assert iter_data is not None + sla_data.extend(iter_data) + + iter_data_mean = { + k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore + for k in sla_comb + } + + sla_results = [ + criterion.print_and_validate(iter_data_mean, k) + for k, criterion in sla_comb.items() + ] + + if all(sla_results): + print("SLA criteria are met.") + max_passing = val + val *= 2 + else: + print("SLA criteria are not met.") + min_failing = val + break + + if val >= max_value: + break + + return sla_data, (max_passing, min_failing) + + +def _find_sla_value( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + sla_comb: SLASweepItem, + base_path: Path, + num_runs: int, + dry_run: bool, + sla_variable: SLAVariable, + min_value: int, + max_value: int, +): + sla_data = list[dict[str, object]]() + + left: int = min_value + right: int = max_value + + while True: + val = (left + right) // 2 + print(f"Testing {sla_variable}: {val} req/s") + + iter_data = run_sla( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb | {sla_variable: val}, + iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, val), + num_runs=num_runs, + dry_run=dry_run, + ) + + assert iter_data is not None + sla_data.extend(iter_data) + + iter_data_mean = { + k: sum(float(run_data[k]) for run_data in iter_data) / len(iter_data) # type: ignore + for k in sla_comb + } + + sla_results = [ + criterion.print_and_validate(iter_data_mean, k) + for k, criterion in sla_comb.items() + ] + + if all(sla_results): + print("SLA criteria are met.") + left = val + else: + print("SLA criteria are not met.") + right = val + + if right - left <= 1: + break + + return sla_data, left + + +def search_sla( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + sla_comb: SLASweepItem, + sla_variable: SLAVariable, + sla_inf_value: int = 65536, # The value that represents infinite QPS + base_path: Path, + num_runs: int, + dry_run: bool, +): + print("[SLA START]") + print(f"SLA criteria: {sla_comb.as_text()}") + + sla_data_0 = run_sla( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb | {sla_variable: sla_inf_value}, + iter_path=_get_sla_iter_path(base_path, sla_comb, sla_variable, sla_inf_value), + num_runs=num_runs, + dry_run=dry_run, + ) + if sla_data_0 is None: + assert dry_run + print("Omitting SLA search.") + print("[SLA END]") + return None + + sla_init_value = math.ceil( + sum(_estimate_sla_value(item, sla_variable) for item in sla_data_0) + / len(sla_data_0) + ) + print(f"Initial {sla_variable} to search: {sla_init_value} req/s.") + + sla_data_1, (sla_min, sla_max) = _estimate_sla_bounds( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + sla_comb=sla_comb, + base_path=base_path, + num_runs=num_runs, + dry_run=dry_run, + sla_variable=sla_variable, + init_value=sla_init_value, + max_value=sla_inf_value, + 
) + print(f"Range of {sla_variable} to search: [{sla_min}, {sla_max}] req/s.") + + sla_data_2, sla_value = _find_sla_value( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + sla_comb=sla_comb, + base_path=base_path, + num_runs=num_runs, + dry_run=dry_run, + sla_variable=sla_variable, + min_value=sla_min, + max_value=sla_max, + ) + + sla_data = sla_data_0 + sla_data_1 + sla_data_2 + print(f"Maximum {sla_variable} for SLA: {sla_value} req/s.") + + with _get_sla_iter_path( + base_path, + sla_comb, + sla_variable, + sla_value=None, + ).open("w") as f: + json.dump(sla_data, f, indent=4) + + print("[SLA END]") + + return sla_data + + +def run_slas( + serve_cmd: list[str], + bench_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + serve_params: ParameterSweep, + bench_params: ParameterSweep, + sla_params: SLASweep, + sla_variable: SLAVariable, + output_dir: Path, + num_runs: int, + dry_run: bool, +): + if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params): + raise ValueError( + f"You should not override `{sla_variable}` in `bench_params` in SLA mode, " + "since it is supposed to be determined automatically." + ) + + all_data = list[dict[str, object]]() + for serve_comb in serve_params: + with ( + run_server( + serve_cmd, + after_bench_cmd, + show_stdout=show_stdout, + serve_overrides=serve_comb, + dry_run=dry_run, + ) + if _sla_needs_server( + serve_comb, + bench_params, + sla_params, + sla_variable, + output_dir, + ) + else contextlib.nullcontext() + ) as server: + for bench_comb in bench_params: + for sla_comb in sla_params: + base_path = _get_sla_base_path(output_dir, serve_comb, bench_comb) + + comb_data = search_sla( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + sla_comb=sla_comb, + sla_variable=sla_variable, + base_path=base_path, + num_runs=num_runs, + dry_run=dry_run, + ) + + if comb_data is not None: + all_data.extend(comb_data) + + if dry_run: + return None + + combined_df = pd.DataFrame.from_records(all_data) + combined_df.to_csv(output_dir / "summary.csv") + + return combined_df + + +@dataclass +class SweepServeSLAArgs(SweepServeArgs): + sla_params: SLASweep + sla_variable: SLAVariable + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # NOTE: Don't use super() as `from_cli_args` calls `cls()` + base_args = SweepServeArgs.from_cli_args(args) + + if args.sla_params: + sla_params = SLASweep.read_json(args.sla_params) + else: + sla_params = SLASweep.from_records([]) + + return cls( + **asdict(base_args), + sla_params=sla_params, + sla_variable=args.sla_variable, + ) + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = super().add_cli_args(parser) + + parser.add_argument( + "--sla-params", + type=str, + required=True, + help="Path to JSON file containing a list of SLA constraints to satisfy. " + 'Each constraint is expressed in `{"<KEY>": "<OP><VALUE>"}` format, ' + 'e.g.: `{"p99_e2el_ms": "<=500"}` means that ' + "the E2E latency should be less than 500ms 99%% of the time. 
" + "Setting this option runs this script in SLA mode, which searches for " + "the maximum `sla_variable` that satisfies the constraints for " + "each combination of `serve_params`, `bench_params`, and `sla_params`.", + ) + parser.add_argument( + "--sla-variable", + type=str, + choices=get_args(SLAVariable), + default="request_rate", + help="Whether to tune request rate or maximum concurrency to satisfy " + "the SLA constraints.", + ) + + return parser + + +def run_main(args: SweepServeSLAArgs): + timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = args.output_dir / timestamp + + if args.resume and not output_dir.exists(): + raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") + + try: + return run_slas( + serve_cmd=args.serve_cmd, + bench_cmd=args.bench_cmd, + after_bench_cmd=args.after_bench_cmd, + show_stdout=args.show_stdout, + serve_params=args.serve_params, + bench_params=args.bench_params, + sla_params=args.sla_params, + sla_variable=args.sla_variable, + output_dir=output_dir, + num_runs=args.num_runs, + dry_run=args.dry_run, + ) + except BaseException as exc: + raise RuntimeError( + f"The script was terminated early. Use `--resume {timestamp}` " + f"to continue the script from its last checkpoint." + ) from exc + + +def main(args: argparse.Namespace): + run_main(SweepServeSLAArgs.from_cli_args(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Tune a variable to meet SLAs under multiple settings." + ) + SweepServeSLAArgs.add_cli_args(parser) + + main(parser.parse_args()) diff --git a/vllm/benchmarks/sweep/server.py b/vllm/benchmarks/sweep/server.py new file mode 100644 index 000000000000..f17578726415 --- /dev/null +++ b/vllm/benchmarks/sweep/server.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +import signal +import subprocess +from types import TracebackType + +import requests +from typing_extensions import Self + + +class ServerProcess: + def __init__( + self, + server_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + ) -> None: + super().__init__() + + self.server_cmd = server_cmd + self.after_bench_cmd = after_bench_cmd + self.show_stdout = show_stdout + + def __enter__(self) -> Self: + self.start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ) -> None: + self.stop() + + def start(self): + # Create new process for clean termination + self._server_process = subprocess.Popen( + self.server_cmd, + start_new_session=True, + stdout=None if self.show_stdout else subprocess.DEVNULL, + # Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches` + env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"}, + ) + + def stop(self): + server_process = self._server_process + + if server_process.poll() is None: + # In case only some processes have been terminated + with contextlib.suppress(ProcessLookupError): + # We need to kill both API Server and Engine processes + os.killpg(os.getpgid(server_process.pid), signal.SIGKILL) + + def run_subcommand(self, cmd: list[str]): + return subprocess.run( + cmd, + stdout=None if self.show_stdout else subprocess.DEVNULL, + check=True, + ) + + def after_bench(self) -> None: + if not self.after_bench_cmd: + self.reset_caches() + return + + self.run_subcommand(self.after_bench_cmd) + + def _get_vllm_server_address(self) -> str: + 
server_cmd = self.server_cmd + + for host_key in ("--host",): + if host_key in server_cmd: + host = server_cmd[server_cmd.index(host_key) + 1] + break + else: + host = "localhost" + + for port_key in ("-p", "--port"): + if port_key in server_cmd: + port = int(server_cmd[server_cmd.index(port_key) + 1]) + break + else: + port = 8000 # The default value in vllm serve + + return f"http://{host}:{port}" + + def reset_caches(self) -> None: + server_cmd = self.server_cmd + + # Use `.endswith()` to match `/bin/...` + if server_cmd[0].endswith("vllm"): + server_address = self._get_vllm_server_address() + print(f"Resetting caches at {server_address}") + + res = requests.post(f"{server_address}/reset_prefix_cache") + res.raise_for_status() + + res = requests.post(f"{server_address}/reset_mm_cache") + res.raise_for_status() + elif server_cmd[0].endswith("infinity_emb"): + if "--vector-disk-cache" in server_cmd: + raise NotImplementedError( + "Infinity server uses caching but does not expose a method " + "to reset the cache" + ) + else: + raise NotImplementedError( + f"No implementation of `reset_caches` for `{server_cmd[0]}` server. " + "Please specify a custom command via `--after-bench-cmd`." + ) diff --git a/vllm/benchmarks/sweep/sla_sweep.py b/vllm/benchmarks/sweep/sla_sweep.py new file mode 100644 index 000000000000..327e3c7c5897 --- /dev/null +++ b/vllm/benchmarks/sweep/sla_sweep.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from typing_extensions import override + + +@dataclass +class SLACriterionBase(ABC): + target: float + + @abstractmethod + def validate(self, actual: float) -> bool: + """Return `True` if this criterion is met; otherwise `False`.""" + raise NotImplementedError + + @abstractmethod + def format_cond(self, lhs: str) -> str: + raise NotImplementedError + + def print_and_validate( + self, + metrics: dict[str, float], + metrics_key: str, + ) -> bool: + metric = metrics[metrics_key] + result = self.validate(metric) + + cond = self.format_cond(f"{metrics_key} = {metric:.2f}") + print(f"Validating SLA: {cond} | " + ("PASSED" if result else "FAILED")) + + return result + + +@dataclass +class SLALessThan(SLACriterionBase): + @override + def validate(self, actual: float) -> bool: + return actual < self.target + + @override + def format_cond(self, lhs: str) -> str: + return f"{lhs}<{self.target:.2f}" + + +@dataclass +class SLALessThanOrEqualTo(SLACriterionBase): + @override + def validate(self, actual: float) -> bool: + return actual <= self.target + + @override + def format_cond(self, lhs: str) -> str: + return f"{lhs}<={self.target:.2f}" + + +@dataclass +class SLAGreaterThan(SLACriterionBase): + @override + def validate(self, actual: float) -> bool: + return actual > self.target + + @override + def format_cond(self, lhs: str) -> str: + return f"{lhs}>{self.target:.2f}" + + +@dataclass +class SLAGreaterThanOrEqualTo(SLACriterionBase): + @override + def validate(self, actual: float) -> bool: + return actual >= self.target + + @override + def format_cond(self, lhs: str) -> str: + return f"{lhs}>={self.target:.2f}" + + +# NOTE: The ordering is important! 
Match longer op_keys first +SLA_CRITERIA: dict[str, type[SLACriterionBase]] = { + "<=": SLALessThanOrEqualTo, + ">=": SLAGreaterThanOrEqualTo, + "<": SLALessThan, + ">": SLAGreaterThan, +} + + +class SLASweep(list["SLASweepItem"]): + @classmethod + def read_json(cls, filepath: os.PathLike): + with open(filepath, "rb") as f: + records = json.load(f) + + return cls.from_records(records) + + @classmethod + def from_records(cls, records: list[dict[str, str]]): + if not isinstance(records, list): + raise TypeError( + f"The SLA sweep should be a list of dictionaries, " + f"but found type: {type(records)}" + ) + + return cls(SLASweepItem.from_record(record) for record in records) + + +class SLASweepItem(dict[str, SLACriterionBase]): + @classmethod + def from_record(cls, record: dict[str, str]): + sla_criteria: dict[str, SLACriterionBase] = {} + + for metric_key, metric_value in record.items(): + for op_key in SLA_CRITERIA: + if metric_value.startswith(op_key): + sla_criteria[metric_key] = SLA_CRITERIA[op_key]( + float(metric_value.removeprefix(op_key)) + ) + break + else: + raise ValueError( + f"Invalid operator for " + f"SLA constraint '{metric_key}={metric_value}'. " + f"Valid operators are: {sorted(SLA_CRITERIA)}", + ) + + return cls(sla_criteria) + + def as_text(self, sep: str = ", ") -> str: + return sep.join(v.format_cond(k) for k, v in self.items()) diff --git a/vllm/benchmarks/sweep/utils.py b/vllm/benchmarks/sweep/utils.py new file mode 100644 index 000000000000..49d7867eaf48 --- /dev/null +++ b/vllm/benchmarks/sweep/utils.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +def sanitize_filename(filename: str) -> str: + return filename.replace("/", "_").replace("..", "__").strip("'").strip('"') diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 04bc29b07aac..866365ac18eb 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -9,7 +9,7 @@ import random import time import warnings -from typing import Any, Optional, Union +from typing import Any import torch import uvloop @@ -34,7 +34,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.utils import merge_async_iterators +from vllm.utils.async_utils import merge_async_iterators def run_vllm( @@ -43,7 +43,7 @@ def run_vllm( engine_args: EngineArgs, do_profile: bool, disable_detokenize: bool = False, -) -> tuple[float, Optional[list[RequestOutput]]]: +) -> tuple[float, list[RequestOutput] | None]: from vllm import LLM, SamplingParams llm = LLM(**dataclasses.asdict(engine_args)) @@ -56,19 +56,19 @@ def run_vllm( " prompt_len and expected_output_len for all requests." ) # Add the requests to the engine. 
- prompts: list[Union[TextPrompt, TokensPrompt]] = [] + prompts: list[TextPrompt | TokensPrompt] = [] sampling_params: list[SamplingParams] = [] for request in requests: - prompts.append( - TokensPrompt( - prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data, - ) + prompt = ( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"]) if "prompt_token_ids" in request.prompt - else TextPrompt( - prompt=request.prompt, multi_modal_data=request.multi_modal_data - ) + else TextPrompt(prompt=request.prompt) ) + if request.multi_modal_data: + assert isinstance(request.multi_modal_data, dict) + prompt["multi_modal_data"] = request.multi_modal_data + prompts.append(prompt) + sampling_params.append( SamplingParams( n=n, @@ -79,7 +79,7 @@ def run_vllm( detokenize=not disable_detokenize, ) ) - lora_requests: Optional[list[LoRARequest]] = None + lora_requests: list[LoRARequest] | None = None if engine_args.enable_lora: lora_requests = [request.lora_request for request in requests] @@ -186,7 +186,7 @@ async def run_vllm_async( engine_args, disable_frontend_multiprocessing=disable_frontend_multiprocessing, ) as llm: - model_config = await llm.get_model_config() + model_config = llm.model_config assert all( model_config.max_model_len >= (request.prompt_len + request.expected_output_len) @@ -197,9 +197,9 @@ async def run_vllm_async( ) # Add the requests to the engine. - prompts: list[Union[TextPrompt, TokensPrompt]] = [] + prompts: list[TextPrompt | TokensPrompt] = [] sampling_params: list[SamplingParams] = [] - lora_requests: list[Optional[LoRARequest]] = [] + lora_requests: list[LoRARequest | None] = [] for request in requests: prompt = ( TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"]) @@ -251,7 +251,7 @@ def run_hf( disable_detokenize: bool = False, ) -> float: llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + model, dtype=torch.float16, trust_remote_code=trust_remote_code ) if llm.config.model_type == "llama": # To enable padding in the HF backend. 
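The refactor above builds either a TokensPrompt or a TextPrompt and attaches multi_modal_data only when the request actually carries any. A rough standalone sketch of that pattern, assuming TextPrompt and TokensPrompt are importable from vllm.inputs and using a placeholder object for the image payload:

```python
from vllm.inputs import TextPrompt, TokensPrompt

def build_prompt(raw_prompt, multi_modal_data=None):
    # Prefer a pre-tokenized prompt when token IDs are available.
    if isinstance(raw_prompt, dict) and "prompt_token_ids" in raw_prompt:
        prompt = TokensPrompt(prompt_token_ids=raw_prompt["prompt_token_ids"])
    else:
        prompt = TextPrompt(prompt=raw_prompt)
    # Only attach multi-modal data if it is present.
    if multi_modal_data:
        assert isinstance(multi_modal_data, dict)
        prompt["multi_modal_data"] = multi_modal_data
    return prompt

print(build_prompt({"prompt_token_ids": [1, 2, 3]}))
print(build_prompt("Describe this image.", {"image": object()}))  # placeholder payload
```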
@@ -696,7 +696,7 @@ def main(args: argparse.Namespace): ) requests = get_requests(args, tokenizer) is_multi_modal = any(request.multi_modal_data is not None for request in requests) - request_outputs: Optional[list[RequestOutput]] = None + request_outputs: list[RequestOutput] | None = None if args.backend == "vllm": if args.async_engine: elapsed_time = uvloop.run( diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 7448bb122152..b5fd67c5b027 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -18,12 +18,12 @@ QuantKey, kFp8StaticTensorSym, kNvfp4Quant, - kStaticTensorScale, ) from vllm.platforms import current_platform from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -66,6 +66,8 @@ def __init__( ) self.FUSED_OP = FUSED_OPS[self.quant_key] + self.silu_and_mul_matcher = MatcherSiluAndMul() + def empty_quant(self, *args, **kwargs): kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} return torch.empty(*args, **kwargs) @@ -80,42 +82,38 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): Fusion for SiluMul+Fp8StaticQuant Pattern """ - def __init__(self, symmetric: bool = True): - quant_key = QuantKey( - dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric - ) - super().__init__(quant_key) + def __init__(self): + super().__init__(kFp8StaticTensorSym) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, ): - at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale - ) - return at2[1] + result_silu_mul = self.silu_and_mul_matcher(input) + result_quant = self.quant_matcher(result_silu_mul, scale) + return result_quant[0] def replacement( - result: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, ): + d = input.shape[-1] // 2 + output_shape = input.shape[:-1] + (d,) + result = torch.empty( + output_shape, device=input.device, dtype=self.quant_dtype + ) at = auto_functionalized( self.FUSED_OP, result=result, input=input, scale=scale ) return at[1] inputs = [ - self.empty_quant(5, 4), # result - empty_bf16(5, 4), # result_silu_mul - empty_bf16(5, 4), # input - empty_fp32(1, 1), # scale + *self.silu_and_mul_matcher.inputs(), # input + self.quant_matcher.inputs()[1], # scale ] + pattern(*inputs) register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) @@ -132,24 +130,22 @@ def register(self, pm_pass: PatternMatcherPass): def pattern( result: torch.Tensor, output_scale: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, ): - at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input) - at2 = auto_functionalized( + result_silu_mul = self.silu_and_mul_matcher(input) + at = auto_functionalized( self.QUANT_OP, output=result, - input=at1[1], + input=result_silu_mul, output_scale=output_scale, input_scale=scale, ) - return at2[1], at2[2] + return at[1], at[2] def replacement( result: torch.Tensor, output_scale: torch.Tensor, - result_silu_mul: torch.Tensor, input: 
torch.Tensor, scale: torch.Tensor, ): @@ -165,7 +161,6 @@ def replacement( inputs = [ self.empty_quant(5, 32), # result empty_i32(128, 4), # output_scale - empty_bf16(5, 64), # result_silu_mul empty_bf16(5, 64), # input empty_fp32(1, 1), # scale ] diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 55bd3d0c60b1..53fd5e74dc0a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -3,23 +3,31 @@ import ast import dataclasses +import hashlib import os import pprint import time -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Callable, Optional +from typing import Any import torch import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher import vllm.envs as envs +from vllm.compilation.inductor_pass import pass_context +from vllm.compilation.partition_rules import ( + inductor_partition_rule_context, + resolve_defined_ops, +) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import is_torch_equal_or_newer +from .caching import VllmSerializableFunction from .compiler_interface import ( CompilerInterface, EagerAdaptor, @@ -49,7 +57,7 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: return InductorAdaptor() else: assert compilation_config.backend == "eager", ( - "Custom backends not supported with CompilationLevel.PIECEWISE" + "Custom backends not supported with CompilationMode.VLLM_COMPILE" ) logger.debug("Using EagerAdaptor") @@ -72,7 +80,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig): - self.cache: dict[tuple[Optional[int], int, str], Any] = dict() + self.cache: dict[tuple[int | None, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -80,6 +88,21 @@ def __init__(self, compilation_config: CompilationConfig): def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) + @contextmanager + def compile_context(self, runtime_shape: int | None = None): + """Provide compilation context for the duration of compilation to set + any torch global properties we want to scope to a single Inductor + compilation (e.g. 
partition rules, pass context).""" + with pass_context(runtime_shape): + if self.compilation_config.use_inductor_graph_partition: + inductor_partition_ops = resolve_defined_ops( + self.compilation_config.splitting_ops + ) + with inductor_partition_rule_context(inductor_partition_ops): + yield + else: + yield + def initialize_cache( self, cache_dir: str, disable_cache: bool = False, prefix: str = "" ): @@ -127,8 +150,8 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: Optional[int] = None, - ) -> Optional[Callable]: + runtime_shape: int | None = None, + ) -> Callable | None: if (runtime_shape, graph_index, self.compiler.name) not in self.cache: return None handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] @@ -160,7 +183,7 @@ def compile( compilation_config: CompilationConfig, graph_index: int = 0, num_graphs: int = 1, - runtime_shape: Optional[int] = None, + runtime_shape: int | None = None, ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time @@ -179,6 +202,7 @@ def compile( # there can be multiple graphs due to piecewise compilation. now = time.time() elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed if runtime_shape is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " @@ -201,9 +225,15 @@ def compile( maybe_key = None else: maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" - compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key - ) + + with self.compile_context(runtime_shape): + compiled_graph, handle = self.compiler.compile( + graph, + example_inputs, + additional_inductor_config, + runtime_shape, + maybe_key, + ) assert compiled_graph is not None, "Failed to compile the graph" @@ -215,10 +245,14 @@ def compile( if graph_index == 0: # adds some info logging for the first graph if runtime_shape is None: - logger.info("Cache the graph for dynamic shape for later use") + logger.info_once( + "Cache the graph for dynamic shape for later use", scope="local" + ) else: - logger.info( - "Cache the graph of shape %s for later use", str(runtime_shape) + logger.info_once( + "Cache the graph of shape %s for later use", + str(runtime_shape), + scope="local", ) if runtime_shape is None: logger.debug( @@ -242,12 +276,17 @@ def compile( elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed if runtime_shape is None: - logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed) + logger.info_once( + "Compiling a graph for dynamic shape takes %.2f s", + elapsed, + scope="local", + ) else: - logger.info( + logger.info_once( "Compiling a graph for shape %s takes %.2f s", runtime_shape, elapsed, + scope="local", ) return compiled_graph @@ -262,7 +301,7 @@ class SplitItem: def split_graph( - graph: fx.GraphModule, ops: list[str] + graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload] ) -> tuple[fx.GraphModule, list[SplitItem]]: # split graph by ops subgraph_id = 0 @@ -271,7 +310,12 @@ def split_graph( for node in graph.graph.nodes: if node.op in ("output", "placeholder"): continue - if node.op == "call_function" and str(node.target) in ops: + # Match node.target against resolved_ops + # node.target can be OpOverloadPacket, need to check .default + if node.op == "call_function" and ( + node.target in resolved_ops + or (hasattr(node.target, "default") and node.target.default in resolved_ops) + ): 
subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) @@ -447,7 +491,7 @@ def set_model_tag(tag: str): class VllmBackend: """The compilation backend for `torch.compile` with vLLM. - It is used for compilation level of `CompilationLevel.PIECEWISE`, + It is used for compilation mode of `CompilationMode.VLLM_COMPILE`, where we customize the compilation. The major work of this backend is to split the graph into @@ -522,7 +566,11 @@ def configure_post_pass(self): self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) inductor_config[PASS_KEY] = self.post_grad_pass_manager - def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + def __call__( + self, graph: fx.GraphModule, example_inputs + ) -> VllmSerializableFunction: + from .caching import _compute_code_hash, compilation_config_hash_factors + vllm_config = self.vllm_config if not self.compilation_config.cache_dir: # no provided cache dir, generate one based on the known factors @@ -530,39 +578,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # the cache dir will be the same so that we can reuse the compiled # graph. - factors = [] - # 0. factors come from the env, for example, The values of - # VLLM_PP_LAYER_PARTITION will affect the computation graph. - env_hash = envs.compute_hash() - factors.append(env_hash) - - # 1. factors come from the vllm_config (it mainly summarizes how the - # model is created) - config_hash = vllm_config.compute_hash() - factors.append(config_hash) - + factors = compilation_config_hash_factors(vllm_config) # 2. factors come from the code files that are traced by Dynamo ( # it mainly summarizes how the model is used in forward pass) - forward_code_files = list(sorted(self.compilation_config.traced_files)) + code_hash = _compute_code_hash(self.compilation_config.traced_files) self.compilation_config.traced_files.clear() - logger.debug( - "Traced files (to be considered for compilation cache):\n%s", - "\n".join(forward_code_files), - ) - hash_content = [] - for filepath in forward_code_files: - hash_content.append(filepath) - if filepath == "<string>": - # This means the function was dynamically generated, with - # e.g. exec(). We can't actually check these. - continue - with open(filepath) as f: - hash_content.append(f.read()) - import hashlib - - code_hash = hashlib.md5( - "\n".join(hash_content).encode(), usedforsecurity=False - ).hexdigest() factors.append(code_hash) # 3. 
compiler hash @@ -593,10 +613,12 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE if disable_cache: - logger.info("vLLM's torch.compile cache is disabled.") + logger.info_once("vLLM's torch.compile cache is disabled.", scope="local") else: - logger.info( - "Using cache directory: %s for vLLM's torch.compile", local_cache_dir + logger.info_once( + "Using cache directory: %s for vLLM's torch.compile", + local_cache_dir, + scope="local", ) self.compiler_manager.initialize_cache( @@ -609,7 +631,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: from .monitor import torch_compile_start_time dynamo_time = time.time() - torch_compile_start_time - logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) + logger.info_once( + "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local" + ) self.compilation_config.compilation_time += dynamo_time # we control the compilation process, each instance can only be @@ -619,9 +643,14 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.graph = graph self.configure_post_pass() - self.split_gm, self.piecewise_graphs = split_graph( - graph, self.compilation_config.splitting_ops - ) + if self.compilation_config.use_inductor_graph_partition: + # Let Inductor decide partitioning; avoid FX-level pre-splitting. + fx_split_ops: list[str] = [] + else: + fx_split_ops = self.compilation_config.splitting_ops or [] + + resolved_split_ops = resolve_defined_ops(fx_split_ops) + self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops) from torch._dynamo.utils import lazy_format_graph_code @@ -645,7 +674,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: graph_path = os.path.join(local_cache_dir, "computation_graph.py") if not os.path.exists(graph_path): - # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa + # code adapted from + # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # use `print_readable` because it can include submodules src = ( "from __future__ import annotations\nimport torch\n" @@ -655,7 +685,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: with open(graph_path, "w") as f: f.write(src) - logger.debug("Computation graph saved to %s", graph_path) + logger.debug_once( + "Computation graph saved to %s", graph_path, scope="local" + ) self._called = True @@ -663,7 +695,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or not self.compilation_config.cudagraph_copy_inputs ): - return self.split_gm + return VllmSerializableFunction( + graph, example_inputs, self.prefix, self.split_gm + ) # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode @@ -708,4 +742,6 @@ def copy_and_call(*args): list_args[index] = static_tensor return self.split_gm(*list_args) - return copy_and_call + return VllmSerializableFunction( + graph, example_inputs, self.prefix, copy_and_call + ) diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 6ee82e74963d..12f1ff5bc044 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Protocol +from collections.abc import Callable +from typing import Any, Protocol from vllm.config import CUDAGraphMode, VllmConfig diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py new file mode 100644 index 000000000000..16e34c2711e9 --- /dev/null +++ b/vllm/compilation/caching.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import inspect +import os +import pickle +from unittest.mock import patch + +import torch +from torch.utils import _pytree as pytree + +import vllm.envs as envs +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +try: + from torch._dynamo.aot_compile import SerializableCallable +except ImportError: + SerializableCallable = object + +assert isinstance(SerializableCallable, type) + +logger = init_logger(__name__) + + +class VllmSerializableFunction(SerializableCallable): + """ + A wrapper around a compiled function by vllm. It will forward the tensor + inputs to the compiled function and return the result. + It also implements a serialization interface to support PyTorch's precompile + with custom backend, so that we can save and load the compiled function on + disk. There's no need to wrap around the compiled function if we don't want + to serialize them in particular cases. + Right now serialization for the custom backend is done via + serializing the Dynamo fx graph plus example inputs. + """ + + def __init__(self, graph_module, example_inputs, prefix, optimized_call): + assert isinstance(graph_module, torch.fx.GraphModule) + self.graph_module = graph_module + self.example_inputs = example_inputs + self.prefix = prefix + self.optimized_call = optimized_call + self.shape_env = None + sym_input = next( + (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None + ) + if sym_input is not None: + self.shape_env = sym_input.node.shape_env + + def __call__(self, *args, **kwargs): + return self.optimized_call(*args, **kwargs) + + @classmethod + def serialize_compile_artifacts( + cls, compiled_fn: "VllmSerializableFunction" + ) -> bytes: + import sympy + from torch._subclasses import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler, Options + + state = compiled_fn.__dict__.copy() + state.pop("optimized_call") + state.pop("shape_env") + for node in state["graph_module"].graph.nodes: + node.meta.pop("source_fn_stack", None) + node.meta.pop("nn_module_stack", None) + + graph_reducer_override = GraphPickler.reducer_override + + def _graph_reducer_override(self, obj): + if ( + inspect.isclass(obj) + and issubclass(obj, sympy.Function) + and hasattr(obj, "_torch_unpickler") + ): + return obj._torch_unpickler, (obj._torch_handler_name,) + if isinstance(obj, FakeTensorMode): + return type(None), () + return graph_reducer_override(self, obj) + + # Mask off tensor inputs since they are large and not needed. 
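As the docstring above describes, VllmSerializableFunction keeps the Dynamo FX graph and example inputs around purely for serialization while forwarding calls to the optimized callable. A toy, self-contained sketch of that wrapper shape (not the real class, which subclasses torch's SerializableCallable):

```python
import torch

class TinySerializableFn:
    """Toy stand-in: hold graph + example inputs, delegate calls."""

    def __init__(self, graph_module, example_inputs, optimized_call):
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.optimized_call = optimized_call
        # Pick up the ShapeEnv from the first symbolic input, if there is one.
        sym = next((i for i in example_inputs if isinstance(i, torch.SymInt)), None)
        self.shape_env = sym.node.shape_env if sym is not None else None

    def __call__(self, *args, **kwargs):
        return self.optimized_call(*args, **kwargs)

def f(x):
    return x * 2 + 1

gm = torch.fx.symbolic_trace(f)
wrapped = TinySerializableFn(gm, [torch.randn(4)], gm.forward)
print(wrapped(torch.ones(4)))  # tensor([3., 3., 3., 3.])
```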
+ state["example_inputs"] = pytree.tree_map_only( + torch.Tensor, lambda _: None, state["example_inputs"] + ) + with patch.object(GraphPickler, "reducer_override", _graph_reducer_override): + state["graph_module"] = GraphPickler.dumps( + state["graph_module"], Options(ops_filter=None) + ) + state["example_inputs"] = GraphPickler.dumps(state["example_inputs"]) + return pickle.dumps(state) + + @classmethod + def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction": + from torch._guards import TracingContext, tracing + from torch._subclasses import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + from vllm.compilation.backends import VllmBackend + + state = pickle.loads(data) + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) + state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) + vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"]) + + def optimized_call(*example_inputs): + """ + On the first run of the optimized call, we rerun the compiler + backend which should result in a cache hit. After the backend + call returns, we just do a one-time replacement of the optimized + call with the compiled function, so that subsequent calls are on + the AOT compiled path. + """ + compile_inputs = [ + inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs) + ] + with tracing(TracingContext(fake_mode)): + fn.optimized_call = vllm_backend( + state["graph_module"], compile_inputs + ).optimized_call + return fn.optimized_call(*example_inputs) + + fn = cls(**state, optimized_call=optimized_call) + return fn + + @property + def co_name(self): + """ + Used for depyf debugging. + """ + return "VllmSerializableFunction" + + +def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]: + factors = [] + # 0. factors come from the env, for example, The values of + # VLLM_PP_LAYER_PARTITION will affect the computation graph. + env_hash = envs.compute_hash() + factors.append(env_hash) + + # 1. factors come from the vllm_config (it mainly summarizes how the + # model is created) + config_hash = vllm_config.compute_hash() + factors.append(config_hash) + return factors + + +def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str: + items = list(sorted(file_contents.items(), key=lambda x: x[0])) + hash_content = [] + for filepath, content in items: + hash_content.append(filepath) + if filepath == "<string>": + # This means the function was dynamically generated, with + # e.g. exec(). We can't actually check these. + continue + hash_content.append(content) + return hashlib.md5( + "\n".join(hash_content).encode(), usedforsecurity=False + ).hexdigest() + + +def _compute_code_hash(files: set[str]) -> str: + logger.debug( + "Traced files (to be considered for compilation cache):\n%s", "\n".join(files) + ) + file_contents = {} + for filepath in files: + # Skip files that don't exist (e.g., <string>, <frozen modules>, etc.) 
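The helpers above derive part of the torch.compile cache key from the traced source files: sorted paths and contents are hashed together, and dynamically generated "<string>" sources are skipped because their content cannot be re-read. A self-contained sketch of that scheme, showing that any source change invalidates the key (file names here are placeholders):

```python
import hashlib

def code_hash(file_contents: dict[str, str]) -> str:
    parts = []
    for filepath, content in sorted(file_contents.items()):
        parts.append(filepath)
        if filepath == "<string>":
            # Dynamically generated code (e.g. from exec()) cannot be checked.
            continue
        parts.append(content)
    return hashlib.md5("\n".join(parts).encode(), usedforsecurity=False).hexdigest()

h1 = code_hash({"model.py": "x = 1", "<string>": ""})
h2 = code_hash({"model.py": "x = 2", "<string>": ""})
print(h1 != h2)  # True: editing a traced file produces a new cache key
```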
+ if not os.path.isfile(filepath): + file_contents[filepath] = "" + else: + with open(filepath) as f: + file_contents[filepath] = f.read() + return _compute_code_hash_with_content(file_contents) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 01fd9f9a1c8e..7294ddce64ba 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from importlib.util import find_spec -from typing import Optional import torch import torch._inductor.pattern_matcher as pm @@ -18,10 +17,14 @@ get_tensor_model_parallel_world_size, ) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass FP8_DTYPE = current_platform.fp8_dtype() @@ -42,11 +45,8 @@ logger = init_logger(__name__) -ALLREDUCE_OP = torch.ops.vllm.all_reduce.default -RMS_OP = torch.ops._C.rms_norm.default -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default -STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default +if hasattr(torch.ops._C, "scaled_fp4_quant"): + STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default class BasePattern: @@ -169,15 +169,23 @@ def replacement( scale_a: torch.Tensor, scale_b: torch.Tensor, ) -> torch.Tensor: - gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 + gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs @@ -296,15 +304,23 @@ def replacement( scale_b: torch.Tensor, cutlass_mm_output: torch.Tensor, ) -> torch.Tensor: - gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 + gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs @@ -416,8 +432,15 @@ def __init__(self, config: VllmConfig): self.dump_patterns(config, self.patterns) - def is_applicable_for_shape(self, shape: Optional[int]) -> bool: - # only do replace for specific shapes + def is_applicable(self, shape: int | None) -> bool: + # This pass is applied on top of the sequence parallelism pass. 
+ # It inherits the same applicability condition as `SequenceParallelismPass`. + # See `SequenceParallelismPass.is_applicable` for more details. + if ( + not self.compilation_config.splitting_ops + or self.compilation_config.use_inductor_graph_partition + ): + return True tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 @@ -469,10 +492,10 @@ def call_trtllm_fused_allreduce_norm( max_token_num: int, pattern_code: int, fuse_rms_quant: bool, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, - scale_out: Optional[torch.Tensor] = None, - scale_factor: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, + scale_out: torch.Tensor | None = None, + scale_factor: torch.Tensor | None = None, ) -> None: num_tokens, hidden_size = allreduce_in.shape element_size = allreduce_in.element_size() @@ -573,10 +596,10 @@ def call_trtllm_fused_allreduce_norm_fake( max_token_num: int, pattern_code: int, fuse_rms_quant: bool, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, - scale_out: Optional[torch.Tensor] = None, - scale_factor: Optional[torch.Tensor] = None, + norm_out: torch.Tensor | None = None, + quant_out: torch.Tensor | None = None, + scale_out: torch.Tensor | None = None, + scale_factor: torch.Tensor | None = None, ) -> None: pass @@ -647,33 +670,24 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def get_inputs(self): - input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4], device=self.device, dtype=self.dtype) + input, weight = self.rmsnorm_matcher.inputs() - return [input, rms_result, weight] + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight] def register(self, pm_pass: PatternMatcherPass): - def pattern( - input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor - ): + def pattern(input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) - rms = auto_functionalized( - RMS_OP, - result=rms_result, - input=allreduce_output, - weight=weight, - epsilon=self.epsilon, - ) - # rms_result, allreduce_output - return rms[1], allreduce_output + rms = self.rmsnorm_matcher(allreduce_output, weight) - def replacement( - input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor - ): + return rms, allreduce_output + + def replacement(input: torch.Tensor, weight: torch.Tensor): residual = torch.zeros_like(input) + rms_result = torch.empty_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -711,29 +725,19 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def get_inputs(self): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - return [ - residual, - input, - weight, - ] + input, residual, weight = self.rmsnorm_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight] def register(self, pm_pass: 
PatternMatcherPass): def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) - rms = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - # input, residual - return rms[1], rms[2] + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + return rms, residual def replacement( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor @@ -757,6 +761,18 @@ def replacement( pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass ) + # Same pattern, but only return the output and not residual + # (helpful for end of graph where residual is not used again) + first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] + + pm.register_replacement( + first_return_only(pattern), + first_return_only(replacement), + self.get_inputs(), + pm.fwd_only, + pm_pass, + ) + class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): """ @@ -777,60 +793,37 @@ def __init__( self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): - input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty( - [1, 8, 4], device=self.device, dtype=self.dtype - ) - quant_result = torch.empty( - [1, 8, 4], device=self.device, dtype=self.quant_dtype - ) - weight = torch.empty([4], device=self.device, dtype=self.dtype) - scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) - return [input, rmsnorm_result, quant_result, weight, scale] + input, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() + + # input goes through allreduce first, always 16-bit + return [input.to(self.dtype), weight, scale] def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized( - RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) + rms = self.rmsnorm_matcher(all_reduce, weight) + quant, _ = self.quant_matcher(rms, scale) + return quant, all_reduce - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=rmsnorm_out_tuple[1], - scale=scale, - ) - - # quant_out, allreduce_output - return quant_out_tuple[1], all_reduce - - def replacement( - input: torch.Tensor, - result_rms: torch.Tensor, - quant_result: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) + result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=result_rms, - quant_out=quant_result, + quant_out=result_quant, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, @@ -870,64 +863,42 @@ def __init__( self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) + self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) + def register(self, pm_pass: PatternMatcherPass): def 
get_inputs(): - input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + input, residual, weight = self.rmsnorm_matcher.inputs() + _, scale = self.quant_matcher.inputs() - residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty( - [4, 4], device=self.device, dtype=self.quant_dtype - ) - scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) - - return [ - quant_result, - residual, - input, - weight, - scale, - ] + # input goes through allreduce first, always 16-bit + return [residual, input.to(self.dtype), weight, scale] def pattern( - quant_result: torch.Tensor, residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): allreduce_output = tensor_model_parallel_all_reduce(input) + rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual) + quant, _ = self.quant_matcher(rms, scale) - fused_add_rmsnorm_out_tuple = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - quant_out_tuple = auto_functionalized( - STATIC_FP8_QUANT_OP, - result=quant_result, - input=fused_add_rmsnorm_out_tuple[1], - scale=scale, - ) - - # quant_out, allreduce_output - return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] + return quant, res def replacement( - quant_result: torch.Tensor, residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, ): + result_quant = torch.empty_like(input, dtype=self.quant_dtype) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, residual=residual, norm_out=None, - quant_out=quant_result, + quant_out=result_quant, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, @@ -964,14 +935,11 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherRMSNorm(epsilon) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) - - rmsnorm_result = torch.empty( - [1, 16, 16], device=self.device, dtype=self.dtype - ) quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) input_global_scale = torch.empty( [1, 1], device=self.device, dtype=torch.float32 @@ -979,36 +947,21 @@ def get_inputs(): weight = torch.empty([16], device=self.device, dtype=self.dtype) output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) - return [ - input, - rmsnorm_result, - quant_result, - weight, - input_global_scale, - output_scale, - ] + return [input, quant_result, weight, input_global_scale, output_scale] def pattern( input: torch.Tensor, - rmsnorm_result: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, output_scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized( - RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon, - ) - + rms = self.rmsnorm_matcher(all_reduce, weight) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, - input=rmsnorm_out_tuple[1], + input=rms, output_scale=output_scale, input_scale=input_global_scale, ) @@ -1018,13 +971,13 @@ def pattern( def replacement( input: torch.Tensor, - result_rms: torch.Tensor, quant_result: torch.Tensor, weight: torch.Tensor, input_global_scale: torch.Tensor, 
output_scale: torch.Tensor, ): residual = torch.zeros_like(input) + result_rms = torch.empty_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -1068,6 +1021,7 @@ def __init__( super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) def register(self, pm_pass: PatternMatcherPass): def get_inputs(): @@ -1099,28 +1053,17 @@ def pattern( input_global_scale: torch.Tensor, ): allreduce_output = tensor_model_parallel_all_reduce(input) - - fused_add_rmsnorm_out_tuple = auto_functionalized( - RMS_ADD_OP, - input=allreduce_output, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, - input=fused_add_rmsnorm_out_tuple[1], + input=rms, output_scale=output_scale, input_scale=input_global_scale, ) # quant_out, allreduce_output, output_scale - return ( - quant_out_tuple[1], - fused_add_rmsnorm_out_tuple[2], - quant_out_tuple[2], - ) + return quant_out_tuple[1], residual, quant_out_tuple[2] def replacement( quant_result: torch.Tensor, diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3b5fecaf189b..0a3f0769db94 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -4,8 +4,9 @@ import copy import hashlib import os +from collections.abc import Callable from contextlib import ExitStack -from typing import Any, Callable, Optional +from typing import Any from unittest.mock import patch import torch @@ -15,9 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig -from vllm.utils import is_torch_equal_or_newer - -from .inductor_pass import pass_context +from vllm.utils.torch_utils import is_torch_equal_or_newer class CompilerInterface: @@ -64,9 +63,9 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, - key: Optional[str] = None, - ) -> tuple[Optional[Callable], Optional[Any]]: + runtime_shape: int | None = None, + key: str | None = None, + ) -> tuple[Callable | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, with a runtime shape. If the `runtime_shape` is None, it means @@ -99,7 +98,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: Optional[int] = None, + runtime_shape: int | None = None, ) -> Callable: """ Load the compiled function from the handle. 
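For context on the compile/load path above: CompilerManager keeps compiled-artifact handles in a dictionary keyed by (runtime_shape, graph_index, compiler name), where a runtime_shape of None denotes the general dynamic-shape graph. A minimal, self-contained sketch of that lookup scheme; the handle values below are just placeholder strings:

```python
# Illustrative cache mirroring the (runtime_shape, graph_index, compiler) key.
cache: dict[tuple[int | None, int, str], str] = {}

def store(runtime_shape: int | None, graph_index: int, compiler: str, handle: str):
    cache[(runtime_shape, graph_index, compiler)] = handle

def load(runtime_shape: int | None, graph_index: int, compiler: str):
    # Returns None on a miss, in which case the graph is compiled from scratch.
    return cache.get((runtime_shape, graph_index, compiler))

store(None, 0, "inductor", "handle_dynamic_subgraph_0")
store(8192, 0, "inductor", "handle_shape_8192_subgraph_0")

print(load(8192, 0, "inductor"))  # cache hit for an already compiled shape
print(load(1, 0, "inductor"))     # None -> this shape has not been compiled yet
```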
@@ -193,14 +192,15 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, - key: Optional[str] = None, - ) -> tuple[Optional[Callable], Optional[Any]]: + runtime_shape: int | None = None, + key: str | None = None, + ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: current_config.update(compiler_config) set_inductor_config(current_config, runtime_shape) + set_functorch_config() if isinstance(runtime_shape, int): dynamic_shapes = "from_example_inputs" @@ -209,13 +209,12 @@ def compile( from torch._inductor import standalone_compile - with pass_context(runtime_shape): - compiled_graph = standalone_compile( - graph, - example_inputs, - dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}, - ) + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None @@ -231,7 +230,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: Optional[int] = None, + runtime_shape: int | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -295,9 +294,9 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, - key: Optional[str] = None, - ) -> tuple[Optional[Callable], Optional[Any]]: + runtime_shape: int | None = None, + key: str | None = None, + ) -> tuple[Callable | None, Any | None]: compilation_counter.num_inductor_compiles += 1 from torch._inductor.compile_fx import compile_fx @@ -310,6 +309,7 @@ def compile( current_config["fx_graph_remote_cache"] = False set_inductor_config(current_config, runtime_shape) + set_functorch_config() # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -332,7 +332,10 @@ def hijack_load(*args, **kwargs): nonlocal file_path compiled_fn = inductor_compiled_graph.current_callable file_path = compiled_fn.__code__.co_filename # noqa - if not file_path.startswith(self.base_cache_dir): + if ( + not file_path.startswith(self.base_cache_dir) + and compiled_fn.__closure__ is not None + ): # hooked in the align_inputs_from_check_idxs function # in torch/_inductor/utils.py for cell in compiled_fn.__closure__: @@ -359,7 +362,10 @@ def hijacked_compile_fx_inner(*args, **kwargs): nonlocal file_path compiled_fn = inductor_compiled_graph.current_callable file_path = compiled_fn.__code__.co_filename # noqa - if not file_path.startswith(self.base_cache_dir): + if ( + not file_path.startswith(self.base_cache_dir) + and compiled_fn.__closure__ is not None + ): # hooked in the align_inputs_from_check_idxs function # in torch/_inductor/utils.py for cell in compiled_fn.__closure__: @@ -456,13 +462,12 @@ def _get_shape_env() -> AlwaysHitShapeEnv: torch._functorch.config.patch(enable_remote_autograd_cache=False) ) - with pass_context(runtime_shape): - compiled_graph = compile_fx( - graph, - example_inputs, - inner_compile=hijacked_compile_fx_inner, - config_patches=current_config, - ) + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config, + ) # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch # compilation cache. 
So turn off the checks if we disable the @@ -488,7 +493,7 @@ def load( graph: fx.GraphModule, example_inputs: list[Any], graph_index: int, - runtime_shape: Optional[int] = None, + runtime_shape: int | None = None, ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) @@ -570,7 +575,7 @@ def metrics_context(self) -> contextlib.AbstractContextManager: Because it is re-entrant, we always set it (even if entering via Dynamo and the context was already entered). We might want to revisit if it - should be set at a different level of compilation. + should be set at a different mode of compilation. This is likely a bug in PyTorch: public APIs should not rely on manually setting up internal contexts. But we also rely on non-public @@ -594,6 +599,10 @@ def set_inductor_config(config, runtime_shape): ) +def set_functorch_config(): + torch._functorch.config.bundled_autograd_cache = False + + class EagerAdaptor(CompilerInterface): name = "eager" @@ -602,9 +611,9 @@ def compile( graph: fx.GraphModule, example_inputs: list[Any], compiler_config: dict[str, Any], - runtime_shape: Optional[int] = None, - key: Optional[str] = None, - ) -> tuple[Optional[Callable], Optional[Any]]: + runtime_shape: int | None = None, + key: str | None = None, + ) -> tuple[Callable | None, Any | None]: compilation_counter.num_eager_compiles += 1 # we don't need to compile the graph, just return the graph itself. # It does not support caching, return None for the handle. diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 9e8de831bcb2..20918099f169 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -27,8 +27,8 @@ class CompilationCounter: num_cache_entries_updated: int = 0 # The number of standalone_compile compiled artifacts saved num_compiled_artifacts_saved: int = 0 - # Number of times a model was loaded with CompilationLevel.DYNAMO_AS_IS - dynamo_as_is_count: int = 0 + # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE + stock_torch_compile_count: int = 0 def clone(self) -> "CompilationCounter": return copy.deepcopy(self) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 4c3ac9e56a37..a2e0abfebc2c 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +from collections.abc import Callable from contextlib import ExitStack -from typing import Any, Callable, Optional +from typing import Any from unittest.mock import patch import torch @@ -16,7 +17,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import weak_ref_tensors +from vllm.utils.torch_utils import weak_ref_tensors logger = init_logger(__name__) @@ -24,12 +25,12 @@ @dataclasses.dataclass class CUDAGraphEntry: batch_descriptor: BatchDescriptor - cudagraph: Optional[torch.cuda.CUDAGraph] = None - output: Optional[Any] = None + cudagraph: torch.cuda.CUDAGraph | None = None + output: Any | None = None # for cudagraph debugging, track the input addresses # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None + input_addresses: list[int] | None = None @dataclasses.dataclass @@ -69,7 +70,7 @@ def __init__( runnable: Callable, vllm_config: VllmConfig, runtime_mode: CUDAGraphMode, - cudagraph_options: Optional[CUDAGraphOptions] = None, + 
cudagraph_options: CUDAGraphOptions | None = None, ): self.runnable = runnable self.vllm_config = vllm_config diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 4f5648d3000a..4a4903035cf9 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -2,8 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import hashlib import inspect -from typing import Callable, Optional, TypeVar, Union, overload +import os +import sys +from collections.abc import Callable +from typing import TypeVar, overload from unittest.mock import patch import torch @@ -11,12 +15,14 @@ from packaging import version from torch._dynamo.symbolic_convert import InliningInstructionTranslator +import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -from vllm.utils import resolve_obj_by_qualname, supports_dynamo +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import supports_dynamo from .monitor import start_monitoring_torch_compile @@ -57,14 +63,14 @@ def _should_ignore_torch_compile(cls) -> bool: @overload def support_torch_compile( *, - enable_if: Optional[Callable[[VllmConfig], bool]] = None, + enable_if: Callable[[VllmConfig], bool] | None = None, ) -> Callable[[_T], _T]: ... @overload def support_torch_compile( *, - dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]], + dynamic_arg_dims: dict[str, int | list[int]] | None, ) -> Callable[[_T], _T]: ... @@ -73,11 +79,11 @@ def support_torch_compile(cls: _T) -> _T: ... def support_torch_compile( - cls: Optional[_T] = None, + cls: _T | None = None, *, - dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, - enable_if: Optional[Callable[[VllmConfig], bool]] = None, -) -> Union[Callable[[_T], _T], _T]: + dynamic_arg_dims: dict[str, int | list[int]] | None = None, + enable_if: Callable[[VllmConfig], bool] | None = None, +) -> Callable[[_T], _T] | _T: """ A decorator to add support for compiling the forward method of a class. @@ -132,8 +138,8 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ... 
""" def cls_decorator_helper(cls: _T) -> _T: - # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` - # to avoid too much indentation for `_support_torch_compile`` + # helper to pass `dynamic_arg_dims` to `_support_torch_compile` + # to avoid too much indentation for `_support_torch_compile` if not hasattr(cls, "forward"): raise TypeError("decorated class should have a forward method.") sig = inspect.signature(cls.forward) @@ -143,9 +149,9 @@ def cls_decorator_helper(cls: _T) -> _T: for k, v in sig.parameters.items(): if v.annotation in [ torch.Tensor, - Optional[torch.Tensor], + torch.Tensor | None, IntermediateTensors, - Optional[IntermediateTensors], + IntermediateTensors | None, ]: inferred_dynamic_arg_dims[k] = 0 @@ -176,10 +182,37 @@ def cls_decorator_helper(cls: _T) -> _T: return cls_decorator_helper +def _model_hash_key(fn) -> str: + import vllm + + sha256_hash = hashlib.sha256() + sha256_hash.update(vllm.__version__.encode()) + sha256_hash.update(fn.__qualname__.encode()) + sha256_hash.update(str(fn.__code__.co_firstlineno).encode()) + return sha256_hash.hexdigest() + + +def _verify_source_unchanged(source_info, vllm_config) -> None: + from .caching import _compute_code_hash, _compute_code_hash_with_content + + file_contents = {} + for source in source_info.inlined_sources: + module = sys.modules[source.module] + file = inspect.getfile(module) + vllm_config.compilation_config.traced_files.add(file) + file_contents[file] = source.content + expected_checksum = _compute_code_hash_with_content(file_contents) + actual_checksum = _compute_code_hash(set(file_contents.keys())) + if expected_checksum != actual_checksum: + raise RuntimeError( + "Source code has changed since the last compilation. Recompiling the model." + ) + + def _support_torch_compile( cls: _T, - dynamic_arg_dims: dict[str, Union[int, list[int]]], - enable_if: Optional[Callable[[VllmConfig], bool]] = None, + dynamic_arg_dims: dict[str, int | list[int]], + enable_if: Callable[[VllmConfig], bool] | None = None, ) -> _T: """ A decorator to add support for compiling the forward method of a class. @@ -201,11 +234,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config enable_compile = enable_if is None or enable_if(vllm_config) - # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner # will handle the compilation, so we don't need to do anything here. 
self.do_not_compile = ( - vllm_config.compilation_config.level - in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS] + vllm_config.compilation_config.mode + in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE] or not supports_dynamo() or _should_ignore_torch_compile(self.__class__) or not enable_compile @@ -215,7 +248,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( - self, compilation_level=vllm_config.compilation_config.level + self, compilation_mode=vllm_config.compilation_config.mode ) cls.__init__ = __init__ @@ -227,6 +260,64 @@ def __call__(self, *args, **kwargs): if self.do_not_compile or torch.compiler.is_compiling(): return self.forward(*args, **kwargs) + if getattr(self, "aot_compiled_fn", None) is not None: + return self.aot_compiled_fn(self, *args, **kwargs) + + cache_dir = None + aot_compilation_path = None + if envs.VLLM_USE_AOT_COMPILE: + """ + When using torch.compile in AOT mode, we store the cache artifacts + under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash} + contains all of the factors except for the source files being + traced through, because we don't actually know which source files + to check at this point (before dynamo runs). + On loading we will actually look at the source files being traced + through. If any source file have changed (compared with the + serialized backend artifacts), then we need to generate a new AOT + compile artifact from scratch. + """ + from .caching import compilation_config_hash_factors + + factors: list[str] = compilation_config_hash_factors(self.vllm_config) + + factors.append(_model_hash_key(self.forward)) + hash_key = hashlib.sha256(str(factors).encode()).hexdigest() + + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, + "torch_aot_compile", + hash_key, + ) + + rank = self.vllm_config.parallel_config.rank + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") + aot_compilation_path = os.path.join(cache_dir, "model") + try: + with ( + set_current_vllm_config(self.vllm_config), + open(aot_compilation_path, "rb") as f, + ): + start_monitoring_torch_compile(self.vllm_config) + loaded_fn = torch.compiler.load_compiled_function(f) + _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) + self.aot_compiled_fn = loaded_fn + except Exception as e: + if os.path.exists(aot_compilation_path): + logger.warning( + "Cannot load aot compilation from path %s, error: %s", + aot_compilation_path, + str(e), + ) + if envs.VLLM_FORCE_AOT_LOAD: + raise e + if getattr(self, "aot_compiled_fn", None) is not None: + logger.info( + "Directly load AOT compilation from path %s", aot_compilation_path + ) + return self.aot_compiled_fn(self, *args, **kwargs) + # the first compilation needs to have dynamic shapes marked if len(self.compiled_codes) < 1: sig = inspect.signature(self.__class__.forward) @@ -275,15 +366,15 @@ def __call__(self, *args, **kwargs): ) # 2. 
every time Dynamo sees a function call, it will inline - # the function by calling InliningInstructionTranslator.inline_call + # the function by calling InliningInstructionTranslator.inline_call_ # we hijack this function to know all the functions called # during Dynamo tracing, and their corresponding files - inline_call = InliningInstructionTranslator.inline_call + inline_call = InliningInstructionTranslator.inline_call_ - def patched_inline_call(parent, func, args, kwargs): - code = func.get_code() + def patched_inline_call(self_): + code = self_.f_code self.vllm_config.compilation_config.traced_files.add(code.co_filename) - return inline_call(parent, func, args, kwargs) + return inline_call(self_) # Disable the C++ compilation of symbolic shape guards. C++-fication # of symbolic shape guards can improve guard overhead. But, since @@ -300,13 +391,21 @@ def patched_inline_call(parent, func, args, kwargs): with ( patch.object( - InliningInstructionTranslator, "inline_call", patched_inline_call + InliningInstructionTranslator, "inline_call_", patched_inline_call ), torch._dynamo.config.patch(**dynamo_config_patches), maybe_use_cudagraph_partition_wrapper(self.vllm_config), _torch27_patch_tensor_subclasses(), ): - output = self.compiled_callable(*args, **kwargs) + if envs.VLLM_USE_AOT_COMPILE: + self.aot_compiled_fn = self.aot_compile(*args, **kwargs) + output = self.aot_compiled_fn(self, *args, **kwargs) + assert aot_compilation_path is not None + assert cache_dir is not None + os.makedirs(cache_dir, exist_ok=True) + self.aot_compiled_fn.save_compiled_function(aot_compilation_path) + else: + output = self.compiled_callable(*args, **kwargs) return output # usually, capturing the model once is enough, and then we can diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 0dffb343f9a2..29462d9ff0e5 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -3,7 +3,6 @@ import operator from collections.abc import Iterable -from typing import Optional, Union import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -150,7 +149,7 @@ def __call__(self, graph: torch.fx.Graph): ) self.nodes_to_remove.clear() - def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]): + def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]): """ Stage a node (or nodes) for removal at the end of the pass. """ @@ -163,8 +162,8 @@ def defunctionalize( self, graph: torch.fx.Graph, node: torch.fx.Node, - mutated_args: dict[int, Union[torch.fx.Node, str]], - args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None, + mutated_args: dict[int, torch.fx.Node | str], + args: tuple[torch.fx.Node | str, ...] | None = None, ): """ De-functionalize a node by replacing it with a call to the original. @@ -176,7 +175,7 @@ def defunctionalize( self._remove(node) def replace_users_with_mutated_args( - self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]] + self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str] ): """ Replace all getitem users of the auto-functionalized node with the @@ -207,7 +206,7 @@ def insert_defunctionalized( self, graph: torch.fx.Graph, node: torch.fx.Node, - args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None, + args: tuple[torch.fx.Node | str, ...] | None = None, ): """ Insert a new defunctionalized node into the graph before node. 
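Relating to the AOT-compile caching added in decorators.py above: the cache key deliberately excludes the traced source files (they are only known after Dynamo runs) and instead folds in the vLLM version plus the forward function's qualified name and first line number, with a separate source checksum verified at load time. A self-contained sketch of that key derivation; the version string below is a placeholder:

```python
import hashlib

def model_hash_key(fn, vllm_version: str) -> str:
    # A new release, a renamed class, or a moved forward() all change the key,
    # so stale AOT artifacts are not reused for a different model definition.
    h = hashlib.sha256()
    h.update(vllm_version.encode())
    h.update(fn.__qualname__.encode())
    h.update(str(fn.__code__.co_firstlineno).encode())
    return h.hexdigest()

class ToyModel:
    def forward(self, x):
        return x

print(model_hash_key(ToyModel.forward, "0.0.0.dev0"))
```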
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index df54e94a03db..8f0ad2d69fbe 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,7 +9,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from torch._ops import OpOverload -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -24,6 +24,7 @@ from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -92,13 +93,19 @@ class RMSNormQuantPattern: def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - - assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" - self.QUANT_OP = QUANT_OPS[key.quant] + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] + self.rmsnorm_matcher = ( + MatcherRMSNorm(epsilon) + if not key.fused_add + else MatcherFusedAddRMSNorm(epsilon) + ) + self.quant_matcher = MatcherQuantFP8(key.quant) + class RMSNormStaticQuantPattern(RMSNormQuantPattern): def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): @@ -112,34 +119,18 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - at1 = auto_functionalized( - RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale - ) + def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) + return self.quant_matcher(result_rms, scale)[0] - # result - return at2[1] + def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. 
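# ------------------------------------------------------------------------------
# The rmsnorm_matcher / quant_matcher helpers used in the patterns below emit
# either the vLLM custom op or the native torch decomposition, depending on
# whether the corresponding CustomOp is enabled. A reduced sketch of that
# dispatch idea; the class and method names here are illustrative, not vLLM
# APIs:
import torch


class MatcherBase:
    def __init__(self, enabled: bool):
        # Bind the tracing entry point once, so a pattern always traces
        # through exactly one of the two implementations.
        self.forward = self.forward_custom if enabled else self.forward_native

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class ToyRMSNormMatcher(MatcherBase):
    def __init__(self, epsilon: float, enabled: bool):
        super().__init__(enabled)
        self.epsilon = epsilon

    def forward_custom(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Stand-in for the fused _C.rms_norm custom-op call.
        return self.forward_native(x, weight)

    def forward_native(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.epsilon) * weight
# ------------------------------------------------------------------------------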
+ input = input.to(dtype=self.model_dtype) - def replacement( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_dtype + ) at = auto_functionalized( self.FUSED_OP, result=result, @@ -153,12 +144,11 @@ def replacement( return at[1] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # result_rms - empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale + # input, weight + *self.rmsnorm_matcher.inputs(), + self.quant_matcher.inputs()[1], # scale ] + pattern(*inputs) pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) @@ -175,33 +165,27 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): def register(self, pm_pass: PatternMatcherPass): def pattern( - result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, scale: torch.Tensor, ): - at = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - at1 = auto_functionalized( - self.QUANT_OP, result=result, input=at[1], scale=scale - ) + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, _ = self.quant_matcher(result_rms, scale) - # result, residual - return at1[1], at[2] + return result, residual def replacement( - result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, + residual: torch.Tensor, scale: torch.Tensor, ): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) at = auto_functionalized( self.FUSED_OP, result=result, @@ -216,11 +200,9 @@ def replacement( return at[1], at[2] inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale + # input, weight, residual + *self.rmsnorm_matcher.inputs(), + self.quant_matcher.inputs()[1], # scale ] pm.register_replacement( @@ -248,34 +230,18 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - at1 = auto_functionalized( - RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon, - ) - at2 = auto_functionalized( - self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None - ) - + def pattern(input: torch.Tensor, weight: torch.Tensor): + result_rms = self.rmsnorm_matcher(input, weight) # result, scale - return at2[1], at2[2] + return self.quant_matcher(result_rms) - def replacement( - result: torch.Tensor, - result_rms: torch.Tensor, - input: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): + def replacement(input: torch.Tensor, weight: torch.Tensor): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. 
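# ------------------------------------------------------------------------------
# How torch._inductor's pattern matcher registers a search/replace pair, in the
# same shape as the RMSNorm+quant registrations above. The relu/clamp rewrite
# is purely illustrative; vLLM performs these registrations under fake mode,
# whereas this sketch just uses empty CPU tensors as example inputs.
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

toy_patterns = PatternMatcherPass(pass_name="toy_fusion_pass")


def toy_pattern(x: torch.Tensor, y: torch.Tensor):
    return torch.relu(x + y)


def toy_replacement(x: torch.Tensor, y: torch.Tensor):
    # Numerically equivalent form the rewritten graph should use instead.
    return torch.clamp(x + y, min=0.0)


example_inputs = [torch.empty(5, 4), torch.empty(5, 4)]
register_replacement(toy_pattern, toy_replacement, example_inputs, fwd_only, toy_patterns)

# Later, inside a post-grad pass, the registered patterns are applied to an FX
# graph and the number of rewrites is returned, e.g.:
#     count = toy_patterns.apply(graph)
# ------------------------------------------------------------------------------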
+ input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) at = auto_functionalized( self.FUSED_OP, result=result, @@ -290,18 +256,10 @@ def replacement( # result, scale return at[1], at[2] - inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # result_rms - empty_bf16(5, 4), # input - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale - ] - pm.register_replacement( pattern, replacement, - inputs, + self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass, ) @@ -323,34 +281,21 @@ def __init__( super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - def pattern( - result: torch.Tensor, - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - ): - at = auto_functionalized( - RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon, - ) - at1 = auto_functionalized( - self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None - ) + def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor): + result_rms, residual = self.rmsnorm_matcher(input, weight, residual) + result, scale = self.quant_matcher(result_rms) - # result, residual, scale - return at1[1], at[2], at1[2] + return result, residual, scale def replacement( - result: torch.Tensor, - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor ): + # In case we're matching native rms-norm, conversions might be + # optimized out. We convert here just to be safe. + input = input.to(dtype=self.model_dtype) + + result = torch.empty_like(input, dtype=self.quant_dtype) + scale = self.quant_matcher.make_scale(input) at = auto_functionalized( self.FUSED_OP, result=result, @@ -365,18 +310,10 @@ def replacement( # result, residual, scale return at[1], at[3], at[2] - inputs = [ - torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result - empty_bf16(5, 4), # input - empty_bf16(5, 4), # residual - empty_bf16(1, 5), # weight - empty_fp32(1, 1), # scale - ] - pm.register_replacement( pattern, replacement, - inputs, + self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass, ) @@ -396,23 +333,25 @@ def __init__(self, config: VllmConfig): pass_name="rmsnorm_quant_fusion_pass" ) + # Make sure fused add patterns are before simple rms norm, + # as the latter is a subset of the former in torch ops for epsilon in [1e-5, 1e-6]: - # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) - # Fuse fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( self.patterns ) - # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + # Fuse rms_norm + static fp8 quant + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( self.patterns ) + # Fuse rms_norm + dynamic per-token fp8 quant + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index ae36cef92653..aaf19e6d4235 100644 --- a/vllm/compilation/fusion_attn.py +++ 
b/vllm/compilation/fusion_attn.py @@ -2,9 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable import torch import torch._inductor.pattern_matcher as pm +from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass @@ -20,7 +22,9 @@ from vllm.utils import round_up from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 +from .fx_utils import is_func from .inductor_pass import enable_fake_mode +from .matcher_utils import MatcherQuantFP8 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass logger = init_logger(__name__) @@ -66,9 +70,13 @@ def empty_quant(self, *args, **kwargs): return torch.empty(*args, **kwargs) @staticmethod - def wrap_trace_fn(process_fx, trace_fn): + def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]): def wrapped(*args, **kwargs): - return process_fx(trace_fn(*args, **kwargs)) + gm = trace_fn(*args, **kwargs) + for process_fx in process_fx_fns: + process_fx(gm) + + return gm return wrapped @@ -77,7 +85,20 @@ def fx_view_to_reshape(gm: torch.fx.GraphModule): from torch._inductor.fx_passes.post_grad import view_to_reshape view_to_reshape(gm) - return gm + + @staticmethod + def remove_noop_permutes(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + if not is_func(node, torch.ops.aten.permute.default): + continue + + dims = node.args[1] + if any(dim != i for i, dim in enumerate(dims)): + continue + + # this is now an identity op, remove + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) def register_if_supported(self, pm_pass: PatternMatcherPass): if self.layer.impl.fused_output_quant_supported(self.quant_key): @@ -108,6 +129,7 @@ def __init__( dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric ) super().__init__(layer, quant_key, dtype) + self.quant_matcher = MatcherQuantFP8(quant_key) def _register(self, pm_pass: PatternMatcherPass): def pattern( @@ -115,7 +137,6 @@ def pattern( k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, scale: torch.Tensor, ): at1 = auto_functionalized( @@ -131,17 +152,14 @@ def pattern( attn_out_view = RESHAPE_OP( at1[1], [q.shape[0], self.num_heads * self.head_size] ) - at2 = auto_functionalized( - self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale - ) - return at2[1] + + return self.quant_matcher(attn_out_view, scale)[0] def replacement( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output_attn: torch.Tensor, - output_quant: torch.Tensor, scale: torch.Tensor, ): # attn output in quant_dtype @@ -164,13 +182,10 @@ def replacement( return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) inputs = [ - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k - self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v - self.empty( - 5, self.num_heads, self.head_size, dtype=self.dtype - ), # attn_output - self.empty_quant(5, self.num_heads * self.head_size), # quant_output + self.empty(5, self.num_heads, self.head_size), # q + self.empty(5, self.num_heads, self.head_size), # k + self.empty(5, self.num_heads, self.head_size), # v + self.empty(5, self.num_heads, self.head_size), # attn_output empty_fp32(1, 1), # scale ] @@ -179,7 +194,9 @@ def replacement( replacement, inputs, 
AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + pm.fwd_only, + AttentionQuantPattern.fx_view_to_reshape, + AttentionQuantPattern.remove_noop_permutes, ), pm_pass, ) @@ -279,7 +296,9 @@ def replacement( replacement, inputs, AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + pm.fwd_only, + AttentionQuantPattern.fx_view_to_reshape, + AttentionQuantPattern.remove_noop_permutes, ), pm_pass, ) diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index 114b53c74c48..f2497950fc22 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -3,11 +3,10 @@ import operator from collections.abc import Iterable, Iterator -from typing import Optional from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._ops import OpOverload +from torch._ops import OpOverload, OpOverloadPacket def is_func(node: fx.Node, target) -> bool: @@ -19,9 +18,7 @@ def is_auto_func(node: fx.Node, op: OpOverload) -> bool: # Returns the first specified node with the given op (if it exists) -def find_specified_fn_maybe( - nodes: Iterable[fx.Node], op: OpOverload -) -> Optional[fx.Node]: +def find_specified_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None: for node in nodes: if node.target == op: return node @@ -36,7 +33,7 @@ def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: # Returns the first auto_functionalized node with the given op (if it exists) -def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: +def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None: for node in nodes: if is_func(node, auto_functionalized) and node.args[0] == op: # noqa return node @@ -52,7 +49,7 @@ def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: # Returns the getitem node that extracts the idx-th element from node # (if it exists) -def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: +def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None: for user in node.users: if is_func(user, operator.getitem) and user.args[1] == idx: return user @@ -67,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node: # An auto-functionalization-aware utility for finding nodes with a specific op -def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: +# Also handles op overload packets and finds all overloads +def find_op_nodes( + op: OpOverload | OpOverloadPacket, graph: fx.Graph +) -> Iterator[fx.Node]: + if isinstance(op, OpOverloadPacket): + for overload in op.overloads(): + overload_op = getattr(op, overload) + yield from find_op_nodes(overload_op, graph) + return + + assert isinstance(op, OpOverload) if not op._schema.is_mutable: yield from graph.find_nodes(op="call_function", target=op) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 9085448d2397..9af635a929b4 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -6,14 +6,15 @@ import inspect import json import types +from collections.abc import Callable from contextlib import contextmanager -from typing import Any, Callable, Optional, Union +from typing import Any import torch from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import 
is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass @@ -27,7 +28,7 @@ class PassContext: - def __init__(self, runtime_shape: Optional[int]): + def __init__(self, runtime_shape: int | None): self.runtime_shape = runtime_shape @@ -38,7 +39,7 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(runtime_shape: Optional[int]): +def pass_context(runtime_shape: int | None): """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ @@ -67,7 +68,7 @@ def uuid(self) -> Any: return InductorPass.hash_source(self) @staticmethod - def hash_source(*srcs: Union[str, Any]): + def hash_source(*srcs: str | Any): """ Utility method to hash the sources of functions or objects. :param srcs: strings or objects to add to the hash. @@ -95,7 +96,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_shape(self, shape: Optional[int]): + def is_applicable(self, shape: int | None): return True @@ -105,9 +106,7 @@ class CallableInductorPass(InductorPass): implementation of the UUID. """ - def __init__( - self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None - ): + def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None): self.callable = callable self._uuid = self.hash_source(callable) if uuid is None else uuid diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py new file mode 100644 index 000000000000..383fe6033a6d --- /dev/null +++ b/vllm/compilation/matcher_utils.py @@ -0,0 +1,238 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + +import torch +from torch._higher_order_ops import auto_functionalized +from torch._ops import OpOverload + +from vllm.config import get_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + _normalize_quant_group_shape, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kNvfp4Quant, +) +from vllm.platforms import current_platform + +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + +QUANT_OPS: dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 +} + +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 + +SILU_MUL_OP = torch.ops._C.silu_and_mul.default + + +class MatcherCustomOp(ABC): + def __init__(self, enabled: bool): + config = get_current_vllm_config() + self.model_dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config else None + + self.enabled = enabled + self.forward = self.forward_custom if enabled else self.forward_native + + @abstractmethod + def forward_custom(self, *args, **kws): + pass + + @abstractmethod + def forward_native(self, 
*args, **kws): + pass + + def __call__(self, *args, **kws): + return self.forward(*args, **kws) + + def empty(self, *args, **kws): + return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws) + + def empty_f32(self, *args, **kws): + return torch.empty(*args, dtype=torch.float32, device=self.device, **kws) + + def inputs(self) -> list[torch.Tensor]: + """Utility for inputs to the pattern""" + raise NotImplementedError + + +class MatcherRMSNorm(MatcherCustomOp): + def __init__(self, epsilon: float, enabled: bool | None = None): + if enabled is None: + enabled = RMSNorm.enabled() + + super().__init__(enabled) + self.epsilon = epsilon + + def inputs(self): + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16) + return [input, weight] + + def forward_custom( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + result = torch.empty_like(input) + _, result = auto_functionalized( + RMS_OP, + result=result, + input=input, + weight=weight, + epsilon=self.epsilon, + ) + + return result + + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + return RMSNorm.forward_static( + input, self.epsilon, input.size(-1), self.model_dtype, weight + ) + + +class MatcherFusedAddRMSNorm(MatcherCustomOp): + def __init__(self, epsilon: float, enabled: bool | None = None): + if enabled is None: + enabled = RMSNorm.enabled() + + super().__init__(enabled) + self.epsilon = epsilon + + def inputs(self): + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16) + residual = self.empty(5, 16) + return [input, weight, residual] + + def forward_custom( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + _, result, residual = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + + return result, residual + + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return RMSNorm.forward_static( + input, self.epsilon, input.size(-1), self.model_dtype, weight, residual + ) + + +class MatcherQuantFP8(MatcherCustomOp): + def __init__(self, quant_key: QuantKey, enabled: bool | None = None): + if enabled is None: + enabled = QuantFP8.enabled() + + super().__init__(enabled) + self.quant_key = quant_key + assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}" + self.QUANT_OP = QUANT_OPS[quant_key] + + assert quant_key.dtype == current_platform.fp8_dtype(), ( + "Only QuantFP8 supported by" + ) + assert quant_key.scale2 is None + self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape) + + def forward_custom( + self, + input: torch.Tensor, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + result = torch.empty( + input.shape, device=input.device, dtype=self.quant_key.dtype + ) + + if self.quant_key.scale.static: + assert scale is not None + _, result = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale + ) + return result, scale + else: + assert scale is None + scale = self.make_scale(input) + _, result, scale = auto_functionalized( + self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None + ) + return result, scale + + def forward_native( + self, + input: torch.Tensor, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, 
torch.Tensor]: + return self.quant_fp8(input, scale) + + def make_scale(self, input: torch.Tensor): + normalized_group_shape = _normalize_quant_group_shape( + input, self.quant_key.scale.group_shape + ) + scale_shape = ( + input.shape[0] // normalized_group_shape[0], + input.shape[1] // normalized_group_shape[1], + ) + + return torch.empty(scale_shape, device=input.device, dtype=torch.float32) + + def inputs(self) -> list[torch.Tensor]: + input = self.empty(5, 16) + if self.quant_key.scale.static: + return [input, self.empty_f32(1, 1)] + + return [input] + + +class MatcherSiluAndMul(MatcherCustomOp): + def __init__(self, enabled: bool | None = None): + if enabled is None: + enabled = SiluAndMul.enabled() + super().__init__(enabled) + + def inputs(self) -> list[torch.Tensor]: + input = self.empty(5, 4) + return [input] + + def forward_custom( + self, + x: torch.Tensor, + ) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (d,) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + result = auto_functionalized(SILU_MUL_OP, result=out, input=x) + return result[1] + + def forward_native( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return SiluAndMul.forward_native(x) diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index d3c437795fab..660fb9887e2c 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -3,7 +3,7 @@ import time -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import CompilationConfig, CompilationMode, VllmConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -18,10 +18,11 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): compilation_config: CompilationConfig = vllm_config.compilation_config path = vllm_config.compile_debug_dump_path() - if compilation_config.level == CompilationLevel.PIECEWISE and path: + if compilation_config.mode == CompilationMode.VLLM_COMPILE and path: import depyf path.mkdir(parents=True, exist_ok=True) + logger.debug("Dumping depyf output to %s", path) global context_manager context_manager = depyf.prepare_debug(path.as_posix()) context_manager.__enter__() @@ -29,9 +30,11 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): def end_monitoring_torch_compile(vllm_config: VllmConfig): compilation_config: CompilationConfig = vllm_config.compilation_config - if compilation_config.level == CompilationLevel.PIECEWISE: - logger.info( - "torch.compile takes %.2f s in total", compilation_config.compilation_time + if compilation_config.mode == CompilationMode.VLLM_COMPILE: + logger.info_once( + "torch.compile takes %.2f s in total", + compilation_config.compilation_time, + scope="local", ) global context_manager if context_manager is not None: diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 3d807ab3a6de..42b8d3daac98 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Union import torch.fx from torch import SymInt @@ -81,77 +80,55 @@ def __call__(self, graph: torch.fx.Graph): graph.erase_node(input) count += 1 - # Case 2: remove this reshape if it produces the original shape - input, shape = node.args[:2] - input_shape = input.meta["val"].shape - if len(shape) != len(input_shape): - # Reshape changing rank, skip - continue - - if shape.count(-1) > 1: - # Invalid reshape 
args, skip - continue - - if self.reshape_all_dims_equivalent(shape, input_shape): - node.replace_all_uses_with(input) - graph.erase_node(node) - count += 1 - - elif is_func(node, torch.ops.aten.slice.Tensor): - # python slicing semantics are different from reshape - # Don't treat -1 as inferred dimension - input, dim_index, start, end = node.args[:4] + # remove reshape/slice if it produces the original shape + if is_func(node, torch.ops.aten.reshape.default) or is_func( + node, torch.ops.aten.slice.Tensor + ): + input = node.args[0] input_shape = input.meta["val"].shape output_shape = node.meta["val"].shape - - if output_shape == input_shape: + if self.all_dims_equivalent(input_shape, output_shape): node.replace_all_uses_with(input) graph.erase_node(node) count += 1 - elif is_func(node, torch.ops.aten.slice_scatter.default): base, view, dim_index, start, end = node.args[:5] base_shape = base.meta["val"].shape view_shape = view.meta["val"].shape - if base_shape == view_shape: + if self.all_dims_equivalent(base_shape, view_shape): node.replace_all_uses_with(view) graph.erase_node(node) count += 1 logger.debug("Removed %s no-op reshapes and slices", count) - # ---------------------- Reshape helpers ---------------------- - def reshape_dims_equivalent( - self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt] - ) -> bool: + # ---------------------- Shape comparison helpers ---------------------- + def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool: """ This function checks if two dimensions are equivalent. :param dim: The dimension arg to reshape/slice :param i_dim: The corresponding dimension in the input tensor :return: Are the dimensions equivalent? - There are three cases in which the dimensions are equivalent: + There are two cases in which the dimensions are equivalent: 1. The dimensions are equal (both integers) - 2. The reshape dimension is -1 (i.e. inferred) - 3. The dimensions both correspond to the same SymInt - - While case 2 does not guarantee the dimensions are equal, - they are equal if all other dimensions are equal. - - In case 3, the reshape dimension is a torch.fx.Node, - and its value is a SymInt. That value is equal to the - input dimension. + 2. 
The dimensions both correspond to the same SymInt """ - # Case 1 and 2 - if dim == i_dim or dim == -1: - return True - # Case 3 - return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim - - def reshape_all_dims_equivalent( - self, - dims: Iterable[Union[int, torch.fx.Node]], - i_dims: Iterable[Union[int, SymInt]], + # Case 1 + if isinstance(i_dim, int) and isinstance(dim, int): + return dim == i_dim + # Case 2 + if isinstance(i_dim, SymInt) and isinstance(dim, SymInt): + return dim == i_dim + return False + + def all_dims_equivalent( + self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt] ) -> bool: - return all(self.reshape_dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) + dims_ = list(dims) + i_dims_ = list(i_dims) + if len(dims_) != len(i_dims_): + # Different ranks can't be equivalent + return False + return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) diff --git a/vllm/compilation/partition_rules.py b/vllm/compilation/partition_rules.py new file mode 100644 index 000000000000..cea4f9a81637 --- /dev/null +++ b/vllm/compilation/partition_rules.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +import logging +from typing import TYPE_CHECKING + +from torch._library.utils import lookup_op + +from vllm.logger import init_logger + +if TYPE_CHECKING: + import torch + +logger = init_logger(__name__) + + +def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]: + """Resolve operator names to OpOverload objects. + + Skips operators that fail to resolve (e.g., operators not registered or + model-specific operators not present in the current model). + + Note: Users should inspect the operator graph before lowering and ensure + the specified operators are present in the final graph. Built-in PyTorch + operators (aten::*, torch::*) may be decomposed, fused, or transformed + during Inductor's compilation passes, so use them with caution. + + Args: + op_names: List of operator names in PyTorch format + (e.g., "vllm::unified_attention") + + Returns: + List of successfully resolved operator overloads + """ + resolved = [] + for op_name in op_names: + try: + resolved.append(lookup_op(op_name)) + except Exception: + # Skip operators that don't exist (e.g., model-specific ops) + # Do not warn for attention ops, warn for others + # (most likely manually specified) + from vllm.config import CompilationConfig + + logger.log( + logging.DEBUG + if op_name in CompilationConfig._attention_ops + else logging.WARNING, + "Failed to resolve operator for CUDAGraph partition: %s", + op_name, + ) + continue + + return resolved + + +@contextlib.contextmanager +def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]): + """Context manager to temporarily register Inductor partition rules. + + Registers custom partition rules for specified operators, forcing the + Inductor scheduler to partition the graph at these operators. The rules + are automatically restored to their previous state on exit. + + Note: Callers should use resolve_defined_ops() to convert operator names + to OpOverload objects before calling this function. + + Args: + overloads: List of resolved operator overload objects. 
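# ------------------------------------------------------------------------------
# Intended usage of the two partition-rule helpers: resolve operator names
# once, then wrap the Inductor lowering in the temporary rule registration so
# the graph is split at those ops. "vllm::unified_attention" is the example
# name from the docstring above; compile_fn stands in for the actual lowering
# call.
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)


def lower_with_partitioning(compile_fn, op_names=("vllm::unified_attention",)):
    overloads = resolve_defined_ops(list(op_names))  # unknown names are skipped
    with inductor_partition_rule_context(overloads):
        return compile_fn()
# ------------------------------------------------------------------------------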
+ """ + if not overloads: + logger.debug("No partition ops provided; skipping rule registration.") + yield + return + + from torch._inductor.scheduler import ( # type: ignore + _custom_should_partition_fns, + register_should_partition_rule, + ) + + def _always_partition(*_args, **_kwargs): + return True + + # Save current state before registering + saved_rules = _custom_should_partition_fns.copy() + + for overload in overloads: + register_should_partition_rule( + overload, + _always_partition, + ) + + logger.debug("Registered inductor partition rules for %d operators", len(overloads)) + + try: + yield + finally: + # Clear and restore previous state + _custom_should_partition_fns.clear() + _custom_should_partition_fns.update(saved_rules) + logger.debug("Restored previous partition rules state.") diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index e323fa1f7734..3bc35a8f7198 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,10 +5,10 @@ from torch import fx as fx from vllm import envs -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import set_env_var +from vllm.utils.system_utils import set_env_var from .post_cleanup import PostCleanupPass from .vllm_inductor_pass import VllmInductorPass @@ -71,9 +71,11 @@ def __call__(self, graph: fx.Graph): shape = get_pass_context().runtime_shape for pass_ in self.passes: - if pass_.is_applicable_for_shape(shape): + if pass_.is_applicable(shape): pass_(graph) VllmInductorPass.dump_prefix += 1 + else: + logger.debug("Skipping %s with shape %s", pass_, shape) # post-cleanup goes before fix_functionalization # because it requires a functional graph @@ -86,27 +88,51 @@ def __call__(self, graph: fx.Graph): def configure(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config - if self.pass_config.enable_noop: - self.passes += [NoOpEliminationPass(config)] - if self.pass_config.enable_sequence_parallelism: - self.passes += [SequenceParallelismPass(config)] - if self.pass_config.enable_async_tp: - self.passes += [AsyncTPPass(config)] - - if self.pass_config.enable_fi_allreduce_fusion: - self.passes += [AllReduceFusionPass(config)] - - if self.pass_config.enable_fusion: - self.passes += [RMSNormQuantFusionPass(config)] - self.passes += [ActivationQuantFusionPass(config)] - - if self.pass_config.enable_attn_fusion: - self.passes += [AttnFusionPass(config)] - - # needs a functional graph - self.post_cleanup = PostCleanupPass(config) - self.fix_functionalization = FixFunctionalizationPass(config) + # Set the current vllm config to allow tracing CustomOp instances + with set_current_vllm_config(config, check_compile=False): + if self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] + + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + if self.pass_config.enable_async_tp: + self.passes += [AsyncTPPass(config)] + + if self.pass_config.enable_fi_allreduce_fusion: + self.passes += [AllReduceFusionPass(config)] + + if self.pass_config.enable_fusion: + self.passes += [RMSNormQuantFusionPass(config)] + self.passes += [ActivationQuantFusionPass(config)] + + if self.pass_config.enable_attn_fusion: + self.passes += [AttnFusionPass(config)] + + # needs a functional graph + self.post_cleanup = PostCleanupPass(config) + self.fix_functionalization = 
FixFunctionalizationPass(config) + + # [HACK: Bug with Inductor graph partition and torch.compile cache] + # In PyTorch 2.9, torch.compile has a bug where the graph + # partition is not taken into account during caching. + # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses + # Inductor graph partition, and VLLM_COMPILE implies there + # is a PostGradPassManager, we put the list of operators to graph + # partition into the PostGradPassManager's uuid (which + # then gets incorporated into Inductor's FX graph cache key). + # Remove this hack whenever torch.compile fixes it. + + # This is the list of operators that vLLM asks Inductor to split. + self.inductor_splitting_ops = [] + if ( + config.compilation_config.use_inductor_graph_partition + and config.compilation_config.splitting_ops is not None + ): + # Sort them so we're not dependent on the ordering. + self.inductor_splitting_ops = sorted( + config.compilation_config.splitting_ops + ) def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) @@ -118,8 +144,16 @@ def uuid(self): affects compilation caching. Its uuid depends on the UUIDs of all dependent passes and the pass config. See InductorPass for more info. """ - state = {"pass_config": self.pass_config.uuid(), "passes": []} + state = { + "pass_config": self.pass_config.uuid(), + "passes": [], + "inductor_splitting_ops": [], + } for pass_ in self.passes: state["passes"].append(pass_.uuid()) state["passes"].append(self.fix_functionalization.uuid()) + + # See [HACK: Bug with Inductor graph partition and torch.compile cache] + state["inductor_splitting_ops"].extend(self.inductor_splitting_ops) + return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 61551766a1c5..2931580afbbb 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -2,7 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch.fx as fx diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 2bc705c3b9a9..31624a8fdcc0 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch import torch._inductor.pattern_matcher as pm @@ -27,7 +26,7 @@ def __init__( epsilon: float, dtype: torch.dtype, device: str, - quant_op: Optional[torch._ops.OpOverload] = None, + quant_op: torch._ops.OpOverload | None = None, **kwargs, ): self.epsilon = epsilon @@ -110,7 +109,7 @@ def __init__( epsilon: float, dtype: torch.dtype, device: str, - quant_op: Optional[torch._ops.OpOverload] = None, + quant_op: torch._ops.OpOverload | None = None, **kwargs, ): super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) @@ -483,7 +482,25 @@ def __init__(self, config: VllmConfig): ).register(self.patterns) self.dump_patterns(config, self.patterns) - def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + def is_applicable(self, shape: int | None) -> bool: + # When sequence parallelism is enabled, the residual tensor from RMSNorm + # needs to be split along the sequence dimension. 
However, this dimension + # is symbolic during piecewise compilation, and splitting symbolic shapes + # is not supported. + # + # This pass is therefore only applied when the sequence dimension is + # concrete: + # 1. In full-graph compilation mode (no Dynamo splitting ops are used). + # For this case we always pad num_tokens to be a multiple of + # tensor_parallel_size, so there's no need to check shape % tp_size == 0. + # 2. For specific shape provided during compilation (e.g., from + # `compile_sizes`), which must be divisible by the tensor-parallel + # size. + if ( + not self.compilation_config.splitting_ops + or self.compilation_config.use_inductor_graph_partition + ): + return True tp_size = get_tensor_model_parallel_world_size() return shape is not None and shape % tp_size == 0 diff --git a/vllm/compilation/torch25_custom_graph_pass.py b/vllm/compilation/torch25_custom_graph_pass.py index ea8b56cf9d6a..1031856cdf00 100644 --- a/vllm/compilation/torch25_custom_graph_pass.py +++ b/vllm/compilation/torch25_custom_graph_pass.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any import torch @@ -23,7 +23,7 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: """ @abstractmethod - def uuid(self) -> Optional[Any]: + def uuid(self) -> Any | None: """ Return an ID to uniquely identify your custom pass implementation. Return None to skip inductor code caching entirely. diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 5aa08220bc2d..08721e3ae4a2 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -3,7 +3,8 @@ import functools import operator import time -from typing import ClassVar, Optional +from dataclasses import dataclass +from typing import ClassVar import regex as re import torch @@ -18,16 +19,28 @@ logger = init_logger(__name__) +@dataclass +class InductorCompilationConfig: + splitting_ops: list[str] | None = None + use_inductor_graph_partition: bool = False + + class VllmInductorPass(InductorPass): """ An inductor pass with access to vLLM PassConfig. It provides timing, logging, and dumping utilities. """ - dump_prefix: ClassVar[Optional[int]] = None + dump_prefix: ClassVar[int | None] = None """Keep track of pass index for debug dump ordering.""" def __init__(self, config: VllmConfig): + # Get only the necessary CompilationConfig for the inductor pass, since + # full `CompilationConfig` contains pointer to model which is unsafe. 
+ self.compilation_config = InductorCompilationConfig( + splitting_ops=config.compilation_config.splitting_ops, + use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition, + ) self.pass_config = config.compilation_config.pass_config self.model_dtype = config.model_config.dtype if config.model_config else None self.device = config.device_config.device if config.device_config else None @@ -101,7 +114,7 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): debug_dump_path.mkdir(parents=True, exist_ok=True) - from vllm.utils import unique_filepath + from vllm.utils.system_utils import unique_filepath file_path = unique_filepath( lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py" @@ -115,7 +128,8 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): f" please add to dump_patterns if there are any errors.\n\n" f"from torch._higher_order_ops.auto_functionalize import " f"auto_functionalized as auto_functionalized\n" - f"from torch._inductor.pattern_matcher import *", + f"from torch._inductor.pattern_matcher import *\n" + f"vllm = torch.ops.vllm", file=f, ) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 71a4e1745d4e..4b10c85209f6 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -4,13 +4,14 @@ import os import sys from abc import abstractmethod +from collections.abc import Callable from contextlib import contextmanager from types import CodeType -from typing import Callable, Optional import torch -from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config +import vllm.envs as envs +from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config from vllm.logger import init_logger logger = init_logger(__name__) @@ -30,7 +31,7 @@ class TorchCompileWrapperWithCustomDispatcher: """ def __init__( - self, compiled_callable: Optional[Callable] = None, compilation_level: int = 0 + self, compiled_callable: Callable | None = None, compilation_mode: int = 0 ): vllm_config = get_current_vllm_config() self.vllm_config = vllm_config @@ -44,6 +45,19 @@ def __init__( options = ( get_current_vllm_config().compilation_config.inductor_compile_config ) + if envs.VLLM_USE_AOT_COMPILE: + options = options or {} + # This effectively drop all the guards. + # We need this because bytecode hook is not used any more to + # drop guards in the AOT compile mode. + options["guard_filter_fn"] = lambda guards: [False for _ in guards] + if hasattr(torch._dynamo.config, "enable_aot_compile"): + torch._dynamo.config.enable_aot_compile = True + else: + msg = "torch._dynamo.config.enable_aot_compile is not " + msg += "available. AOT compile is disabled and please " + msg += "upgrade PyTorch version to use AOT compile." + logger.warning(msg) compiled_callable = torch.compile( self.forward, fullgraph=True, backend=backend, options=options @@ -58,11 +72,20 @@ def __init__( # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = ( - compilation_level >= CompilationLevel.DYNAMO_ONCE + compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE ) + def aot_compile(self, *args, **kwargs): + if not hasattr(self.compiled_callable, "aot_compile"): + raise RuntimeError( + "aot_compile is not supported by the current configuration. 
" + + "Please make sure torch.compile is enabled with the latest " + + f"version of PyTorch (current using torch: {torch.__version__})" + ) + return self.compiled_callable.aot_compile((args, kwargs)) + def __call__(self, *args, **kwargs): - """Implement the dispatch logic here, beyond the torch.compile level. + """Implement the dispatch logic here, beyond the torch.compile mode. NOTE: this function can have additional arguments beyond the forward method, for directly dispatching to the compiled code. """ diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 7c5052c822f8..7f1cc5202420 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1,42 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.config.cache import ( - BlockSize, - CacheConfig, - CacheDType, - MambaDType, - PrefixCachingHashAlgo, -) +from vllm.config.cache import CacheConfig from vllm.config.compilation import ( CompilationConfig, - CompilationLevel, + CompilationMode, CUDAGraphMode, PassConfig, ) -from vllm.config.device import Device, DeviceConfig +from vllm.config.device import DeviceConfig from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_transfer import KVTransferConfig from vllm.config.load import LoadConfig from vllm.config.lora import LoRAConfig from vllm.config.model import ( - ConvertOption, - HfOverrides, - LogprobsMode, ModelConfig, - ModelDType, - ModelImpl, - RunnerOption, - TaskOption, - TokenizerMode, iter_architecture_defaults, try_match_architecture_defaults, ) -from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig -from vllm.config.observability import DetailedTraceModules, ObservabilityConfig -from vllm.config.parallel import DistributedExecutorBackend, EPLBConfig, ParallelConfig +from vllm.config.multimodal import MultiModalConfig +from vllm.config.observability import ObservabilityConfig +from vllm.config.parallel import EPLBConfig, ParallelConfig from vllm.config.pooler import PoolerConfig -from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy +from vllm.config.scheduler import SchedulerConfig from vllm.config.speculative import SpeculativeConfig from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.structured_outputs import StructuredOutputsConfig @@ -56,20 +42,17 @@ set_current_vllm_config, ) +# __all__ should only contain classes and functions. +# Types and globals should be imported from their respective modules. 
__all__ = [ # From vllm.config.cache - "BlockSize", "CacheConfig", - "CacheDType", - "MambaDType", - "PrefixCachingHashAlgo", # From vllm.config.compilation "CompilationConfig", - "CompilationLevel", + "CompilationMode", "CUDAGraphMode", "PassConfig", # From vllm.config.device - "Device", "DeviceConfig", # From vllm.config.kv_events "KVEventsConfig", @@ -80,34 +63,20 @@ # From vllm.config.lora "LoRAConfig", # From vllm.config.model - "ConvertOption", - "HfOverrides", - "LogprobsMode", "ModelConfig", - "ModelDType", - "ModelImpl", - "RunnerOption", - "TaskOption", - "TokenizerMode", "iter_architecture_defaults", "try_match_architecture_defaults", # From vllm.config.multimodal - "MMCacheType", - "MMEncoderTPMode", "MultiModalConfig", # From vllm.config.observability - "DetailedTraceModules", "ObservabilityConfig", # From vllm.config.parallel - "DistributedExecutorBackend", "EPLBConfig", "ParallelConfig", # From vllm.config.pooler "PoolerConfig", # From vllm.config.scheduler - "RunnerType", "SchedulerConfig", - "SchedulerPolicy", # From vllm.config.speculative "SpeculativeConfig", # From vllm.config.speech_to_text diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 519e3d8b9bde..cf2977622a0b 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -3,16 +3,15 @@ import hashlib from dataclasses import field -from typing import TYPE_CHECKING, Any, Literal, Optional, get_args +from typing import TYPE_CHECKING, Any, Literal -from pydantic import SkipValidation, model_validator +from pydantic import Field, SkipValidation, field_validator from pydantic.dataclasses import dataclass -from typing_extensions import Self -import vllm.envs as envs from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import GiB_bytes, get_cpu_memory +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_utils import get_cpu_memory if TYPE_CHECKING: from vllm.config.parallel import ParallelConfig @@ -21,7 +20,7 @@ logger = init_logger(__name__) -BlockSize = Literal[1, 8, 16, 32, 64, 128] +BlockSize = Literal[1, 8, 16, 32, 64, 128, 256] CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @@ -39,7 +38,7 @@ class CacheConfig: This config has no static default. If left unspecified by the user, it will be set in `Platform.check_and_update_config()` based on the current platform.""" - gpu_memory_utilization: float = 0.9 + gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) """The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory utilization. If unspecified, will use the default value of 0.9. This is a @@ -47,7 +46,7 @@ class CacheConfig: not matter if you have another vLLM instance running on the same GPU. For example, if you have two vLLM instances running on the same GPU, you can set the GPU memory utilization to 0.5 for each instance.""" - swap_space: float = 4 + swap_space: float = Field(default=4, ge=0) """Size of the CPU swap space per GPU (in GiB).""" cache_dtype: CacheDType = "auto" """Data type for kv cache storage. If "auto", will use model data type. @@ -60,20 +59,20 @@ class CacheConfig: is_attention_free: bool = False """Whether the model is attention-free. 
This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" - num_gpu_blocks_override: Optional[int] = None + num_gpu_blocks_override: int | None = None """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` if specified. Does nothing if `None`. Used for testing preemption.""" - sliding_window: Optional[int] = None + sliding_window: int | None = None """Sliding window size for the KV cache. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" - enable_prefix_caching: Optional[bool] = None + enable_prefix_caching: bool | None = None """Whether to enable prefix caching. Enabled by default for V1.""" prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" """Set the hash algorithm for prefix caching:\n - "sha256" uses Pickle for object serialization before hashing.\n - "sha256_cbor" provides a reproducible, cross-language compatible hash. It serializes objects using canonical CBOR and hashes them with SHA-256.""" - cpu_offload_gb: float = 0 + cpu_offload_gb: float = Field(default=0, ge=0) """The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to increase the GPU memory size. For example, if you have one 24 GB GPU and @@ -86,12 +85,12 @@ class CacheConfig: """This enables dynamic calculation of `k_scale` and `v_scale` when kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model checkpoint if available. Otherwise, the scales will default to 1.0.""" - cpu_kvcache_space_bytes: Optional[int] = None + cpu_kvcache_space_bytes: int | None = None """(CPU backend only) CPU key-value cache space.""" - mamba_page_size_padded: Optional[int] = None + mamba_page_size_padded: int | None = None """ Optional override for mamba page size; used by hybrid mamba/attention models to ensure exact alignment with attention page size.""" - mamba_block_size: Optional[int] = None + mamba_block_size: int | None = None """Size of a contiguous cache block in number of tokens for mamba cache.""" mamba_cache_dtype: MambaDType = "auto" """The data type to use for the Mamba cache (both the conv as well as the @@ -103,9 +102,9 @@ class CacheConfig: for the ssm state will be determined by mamba_cache_dtype.""" # Will be set after profiling. - num_gpu_blocks: Optional[int] = field(default=None, init=False) + num_gpu_blocks: int | None = field(default=None, init=False) """The number of blocks to allocate for GPU memory.""" - num_cpu_blocks: Optional[int] = field(default=None, init=False) + num_cpu_blocks: int | None = field(default=None, init=False) """The number of blocks to allocate for CPU memory.""" kv_sharing_fast_prefill: bool = False @@ -118,7 +117,7 @@ class CacheConfig: necessary for implementing this optimization in some models (e.g. Gemma3n) """ - kv_cache_memory_bytes: Optional[int] = None + kv_cache_memory_bytes: int | None = None """Size of KV Cache per GPU in bytes. By default, this is set to None and vllm can automatically infer the kv cache size based on gpu_memory_utilization. 
However, users may want to manually specify @@ -147,74 +146,33 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - def __post_init__(self) -> None: - self.swap_space_bytes = self.swap_space * GiB_bytes - - self._verify_cache_dtype() - self._verify_prefix_caching() - def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info return {key: str(value) for key, value in self.__dict__.items()} - @model_validator(mode="after") - def _verify_args(self) -> Self: - if self.cpu_offload_gb < 0: - raise ValueError( - f"CPU offload space must be non-negative, but got {self.cpu_offload_gb}" - ) - - if self.gpu_memory_utilization > 1.0: - raise ValueError( - "GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}." - ) - - return self - - def _verify_cache_dtype(self) -> None: - if self.cache_dtype == "auto": - pass - elif self.cache_dtype in get_args(CacheDType): - if self.cache_dtype.startswith("fp8"): - logger.info( - "Using fp8 data type to store kv cache. It reduces the GPU " - "memory footprint and boosts the performance. " - "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor." - ) - else: - raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") - - def _verify_prefix_caching(self) -> None: - if not self.enable_prefix_caching: - return - - if self.sliding_window is not None and not envs.VLLM_USE_V1: - raise NotImplementedError( - "Prefix caching is not supported with sliding window. " - "Run with --disable-sliding-window to use prefix caching." - ) - - if self.enable_prefix_caching and self.prefix_caching_hash_algo not in get_args( - PrefixCachingHashAlgo - ): - raise ValueError( - "Unknown prefix caching hash algorithm: " - f"{self.prefix_caching_hash_algo}. Must be one of " - f"{get_args(PrefixCachingHashAlgo)}." + @field_validator("cache_dtype", mode="after") + @classmethod + def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: + if cache_dtype.startswith("fp8"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor." ) + return cache_dtype def verify_with_parallel_config( self, parallel_config: ParallelConfig, ) -> None: + swap_space_bytes = self.swap_space * GiB_bytes total_cpu_memory = get_cpu_memory() # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel # group are in the same node. However, the GPUs may span multiple nodes. 
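# ------------------------------------------------------------------------------
# The hand-written _verify_args checks removed above are replaced by
# declarative pydantic constraints (Field bounds plus a field_validator). The
# same idea in a toy config, so the failure mode is visible; ToyCacheConfig is
# illustrative, not a vLLM class.
from pydantic import Field, ValidationError, field_validator
from pydantic.dataclasses import dataclass


@dataclass
class ToyCacheConfig:
    gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
    swap_space: float = Field(default=4, ge=0)
    cache_dtype: str = "auto"

    @field_validator("cache_dtype", mode="after")
    @classmethod
    def _validate_cache_dtype(cls, v: str) -> str:
        if v not in ("auto", "fp8", "bfloat16"):
            raise ValueError(f"unknown kv cache dtype: {v}")
        return v


ToyCacheConfig()  # ok: defaults satisfy the constraints
try:
    ToyCacheConfig(gpu_memory_utilization=1.5)
except ValidationError:
    pass  # rejected by the le=1 bound, matching the old runtime check
# ------------------------------------------------------------------------------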
num_gpus_per_node = parallel_config.tensor_parallel_size - cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + cpu_memory_usage = swap_space_bytes * num_gpus_per_node msg = ( f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 9346bfa6307a..c24a94091be4 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -4,9 +4,10 @@ import enum import hashlib from collections import Counter +from collections.abc import Callable from dataclasses import asdict, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar from pydantic import TypeAdapter, field_validator from pydantic.dataclasses import dataclass @@ -14,7 +15,9 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname +from vllm.platforms import current_platform +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import is_torch_equal_or_newer if TYPE_CHECKING: from vllm.config import VllmConfig @@ -24,12 +27,20 @@ logger = init_logger(__name__) -class CompilationLevel: - # constants for the levels of the compilation process - NO_COMPILATION = 0 - DYNAMO_AS_IS = 1 - DYNAMO_ONCE = 2 - PIECEWISE = 3 +class CompilationMode: + """The compilation approach used for torch.compile-based compilation of the + model.""" + + NONE = 0 + """No torch.compile compilation is applied, model runs in fully eager pytorch mode. + The model runs as-is.""" + STOCK_TORCH_COMPILE = 1 + """The standard `torch.compile` compilation pipeline.""" + DYNAMO_TRACE_ONCE = 2 + """Single Dynamo trace through the model, avoiding recompilation.""" + VLLM_COMPILE = 3 + """Custom vLLM Inductor-based backend with caching, piecewise compilation, + shape specialization, and custom passes.""" class CUDAGraphMode(enum.Enum): @@ -50,11 +61,14 @@ def decode_mode(self) -> "CUDAGraphMode": def mixed_mode(self) -> "CUDAGraphMode": return CUDAGraphMode(self.value[1]) if self.separate_routine() else self + def has_mode(self, mode: "CUDAGraphMode") -> bool: + assert not mode.separate_routine() + if self.separate_routine(): + return mode.value in self.value + return self == mode + def requires_piecewise_compilation(self) -> bool: - return ( - self.decode_mode() == CUDAGraphMode.PIECEWISE - or self.mixed_mode() == CUDAGraphMode.PIECEWISE - ) + return self.has_mode(CUDAGraphMode.PIECEWISE) def max_cudagraph_mode(self) -> "CUDAGraphMode": return CUDAGraphMode(max(self.value)) if self.separate_routine() else self @@ -129,7 +143,7 @@ class CompilationConfig: """Configuration for compilation. 
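The new CUDAGraphMode.has_mode helper above collapses the separate decode/mixed checks into one membership test. A standalone sketch with a demo enum (member values are illustrative; as in the real enum, "separate routine" modes carry a (decode, mixed) tuple):

import enum


class DemoGraphMode(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_AND_PIECEWISE = (2, 1)   # decode runs FULL, mixed prefill/decode batches run PIECEWISE

    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

    def has_mode(self, mode: "DemoGraphMode") -> bool:
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

    def requires_piecewise_compilation(self) -> bool:
        return self.has_mode(DemoGraphMode.PIECEWISE)


assert DemoGraphMode.FULL_AND_PIECEWISE.has_mode(DemoGraphMode.PIECEWISE)
assert not DemoGraphMode.FULL.requires_piecewise_compilation()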
It has three parts: - Top-level Compilation control: - - [`level`][vllm.config.CompilationConfig.level] + - [`mode`][vllm.config.CompilationConfig.mode] - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path] - [`cache_dir`][vllm.config.CompilationConfig.cache_dir] - [`backend`][vllm.config.CompilationConfig.backend] @@ -140,6 +154,8 @@ class CompilationConfig: - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode] - [`cudagraph_capture_sizes`] [vllm.config.CompilationConfig.cudagraph_capture_sizes] + - [`max_cudagraph_capture_size`] + [vllm.config.CompilationConfig.max_cudagraph_capture_size] - [`cudagraph_num_of_warmups`] [vllm.config.CompilationConfig.cudagraph_num_of_warmups] - [`cudagraph_copy_inputs`] @@ -165,22 +181,34 @@ class CompilationConfig: """ # Top-level Compilation control - level: Optional[int] = None - """The level of compilation: - - - None: If None, we will select the default compilation level. - For V1 engine this is 3, for V0 engine this is 0. - - 0: no compilation. - - 1: dynamo as is. - - 2: dynamo once. - - 3: piecewise compilation.""" - debug_dump_path: Optional[Path] = None + level: int | None = None + """ + Level is deprecated and will be removed in the next release, + either 0.12.0 or 0.11.2 whichever is soonest. + Please use mode. Currently all levels are mapped to mode. + """ + # Top-level Compilation control + mode: int | None = None + """The compilation approach used for torch.compile-based compilation of the + model. + + - None: If None, we will select the default compilation mode. + For V1 engine this is 3. + - 0: NONE: No torch.compile compilation is applied, model runs in fully + eager pytorch mode. The model runs as-is. + - 1: STOCK_TORCH_COMPILE: The standard `torch.compile` compilation pipeline. + - 2: DYNAMO_TRACE_ONCE: Single Dynamo trace through the model, avoiding + recompilation by removing guards. + Requires no dynamic-shape-dependent control-flow. + - 3: VLLM_COMPILE: Custom vLLM Inductor-based backend with caching, + piecewise compilation, shape specialization, and custom passes.""" + debug_dump_path: Path | None = None """The path to dump the debug information.""" cache_dir: str = "" """The directory to store the compiled graph, to accelerate Inductor compilation. By default, it will use model-related information to generate a cache directory.""" - backend: str = "inductor" + backend: str = "" """The backend for compilation. It needs to be a string: - "" (empty string): use the default backend ("inductor" on CUDA-alike @@ -190,13 +218,14 @@ class CompilationConfig: backend function. We use string to avoid serialization issues when using compilation in a - distributed setting. When the compilation level is 1 or 2, the backend is + distributed setting. When the compilation mode is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the - compilation level is 3, the backend is used for the piecewise compilation + compilation mode is 3, the backend is used for the piecewise compilation (it sees a part of the graph). The backend can not be custom for compilation - level 3. Furthermore, compilation is only piecewise if splitting ops is set - accordingly and use_inductor_cudagraphs_partition is off. Note that the - default options for splitting ops are sufficient for piecewise compilation. + mode 3, i.e. the backend must be either eager or inductor. Furthermore, + compilation is only piecewise if splitting ops is set accordingly and + use_inductor_graph_partition is off. 
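Hedged usage sketch for the new mode knob: assuming a vLLM build that includes this change, the value can be passed through compilation_config in dict form, where 3 corresponds to CompilationMode.VLLM_COMPILE described above. The model name is only illustrative:

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",   # any small model; name chosen for illustration
    compilation_config={"mode": 3},        # VLLM_COMPILE: piecewise vLLM Inductor backend
)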
Note that the default options for + splitting ops are sufficient for piecewise compilation. """ custom_ops: list[str] = field(default_factory=list) """Fine-grained control over which custom ops to enable/disable. Use 'all' @@ -208,18 +237,33 @@ class CompilationConfig: - 'none,+op1,+op2' to enable only op1 and op2 By default, all custom ops are enabled when running without Inductor and - disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. + disabled when running with Inductor: mode>=VLLM_COMPILE and use_inductor=True. Inductor generates (fused) Triton kernels for disabled custom ops.""" - splitting_ops: Optional[list[str]] = None - """A list of ops to split the full graph into subgraphs, used in piecewise - compilation.""" + splitting_ops: list[str] | None = None + """A list of ops to exclude from cudagraphs, used in piecewise compilation. + + The behavior depends on use_inductor_graph_partition: + + - When use_inductor_graph_partition=False (default): + These ops are used for Dynamo FX-level graph splitting. The graph is + split at these ops before Inductor compilation, creating separate + subgraphs for cudagraph capture. + + - When use_inductor_graph_partition=True: + These ops are used to register Inductor partition rules. The graph + partitioning happens at Inductor codegen time after all passes and + fusions are finished, allowing compilation and custom passes to operate + on the full graph while still excluding these ops from cudagraphs. + + If None, defaults to attention ops for piecewise cudagraphs. + If empty list [], no ops are excluded (suitable for full cudagraphs).""" # Inductor capture - use_inductor: Optional[bool] = None + use_inductor: bool | None = None """ Whether to use inductor compilation. - This flag is deprecated and will be removed. + This flag is deprecated and will be removed in the next release 0.12.0. Please use the 'backend' option instead. - False: inductor compilation is not used. graph runs in eager @@ -228,12 +272,12 @@ class CompilationConfig: One graph for symbolic shape and one graph per size in compile_sizes are compiled using configurations in inductor_compile_config. - This setting is ignored if level<PIECEWISE. + This setting is ignored if mode<VLLM_COMPILE. For future compatibility: If use_inductor is True, backend="inductor" otherwise backend="eager". """ - compile_sizes: Optional[list[Union[int, str]]] = None + compile_sizes: list[int | str] | None = None """Sizes to compile for inductor. In addition to integers, it also supports "cudagraph_capture_sizes" to specify the sizes for cudagraph capture.""" @@ -248,7 +292,7 @@ class CompilationConfig: constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" # CudaGraph compilation - cudagraph_mode: Optional[CUDAGraphMode] = None + cudagraph_mode: CUDAGraphMode | None = None """ The mode of the cudagraph: @@ -278,32 +322,30 @@ class CompilationConfig: Currently, the cudagraph mode is only used for the v1 engine. Note that the cudagraph logic is generally orthogonal to the compilation logic. While piecewise cudagraphs require piecewise - compilation (level=PIECEWISE and non-empty splitting_ops), full + compilation (mode=VLLM_COMPILE and non-empty splitting_ops), full cudagraphs are supported with and without compilation. Warning: This flag is new and subject to change in addition more modes may be added. """ use_cudagraph: bool = True - """Whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. 
+ """Whether to use cudagraph inside compilation: + + - False: cudagraph inside compilation is not used.\n - True: cudagraph inside compilation is used. It requires that all input buffers have fixed addresses, and all splitting ops write their outputs to input buffers. - In the vLLM V1 Engine, this flag only applies for - CompilationLevel.PIECEWISE (aka -O3). - Note that this is orthogonal to the cudagraph capture logic - outside of compilation. + Warning: This flag is deprecated and will be removed in the next major or - minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE - instead. + minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=FULL_AND + _PIECEWISE instead. """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. Only after that, the execution will be recorded, and the recorded cudagraph will be used for subsequent runs.""" - cudagraph_capture_sizes: Optional[list[int]] = None + cudagraph_capture_sizes: list[int] | None = None """Sizes to capture cudagraph. - None (default): capture sizes are inferred from vllm config. - list[int]: capture sizes are specified as given.""" @@ -315,7 +357,7 @@ class CompilationConfig: internally managed buffer. Default is False. Note that this flag is only effective when cudagraph_mode is PIECEWISE. """ - full_cuda_graph: Optional[bool] = False + full_cuda_graph: bool | None = False """whether to use a full cuda graph for the entire forward pass rather than splitting certain operations such as attention into subgraphs. Thus this flag cannot be used together with splitting_ops. This may provide @@ -324,6 +366,14 @@ class CompilationConfig: minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode= FULL_AND_PIECEWISE instead. """ + cudagraph_specialize_lora: bool = True + """Whether to create separate cuda graphs for cases with and without active + LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used + for all cases, incurring the overhead of running LoRA ops even when no + adapters are active. Setting this to True will remove this overhead at the + cost of increased startup time and slightly higher memory usage. + When `enable_lora` is False, this option has no effect. + """ use_inductor_graph_partition: bool = False """Use inductor graph partition to split the graph at cudagraph_unsafe ops. @@ -348,8 +398,22 @@ class CompilationConfig: pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" - max_capture_size: int = field(default=None, init=False) # type: ignore - """not configurable, computed after init""" + max_cudagraph_capture_size: int | None = field(default=None) + """The maximum cudagraph capture size. + + If cudagraph_capture_sizes is specified, this will be set to the largest + size in that list (or checked for consistency if specified). If + cudagraph_capture_sizes is not specified, the list of sizes is generated + automatically following the pattern: + + [1, 2, 4] + list(range(8, 256, 8)) + list( + range(256, max_cudagraph_capture_size + 1, 16)) + + If not specified, max_cudagraph_capture_size is set to min(max_num_seqs*2, + 512) by default. This voids OOM in tight memory scenarios with small + max_num_seqs, and prevents capture of many large graphs (>512) that would + greatly increase startup time with limited performance benefit. 
+ """ local_cache_dir: str = field(default=None, init=False) # type: ignore """local cache dir for each rank""" bs_to_padded_graph_size: list[int] = field( @@ -358,7 +422,7 @@ class CompilationConfig: ) """optimization: Intuitively, bs_to_padded_graph_size should be dict[int, int]. - since we know all keys are in a range [0, max_capture_size], + since we know all keys are in a range [0, max_cudagraph_capture_size], we can optimize it to list[int] for better lookup performance.""" # keep track of enabled and disabled custom ops @@ -377,16 +441,19 @@ class CompilationConfig: model code, e.g., Attention, FusedMOE when dp_size>1.""" # Attention ops; used for piecewise cudagraphs + # Use PyTorch operator format: "namespace::name" _attention_ops: ClassVar[list[str]] = [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - "vllm.mamba_mixer2", - "vllm.mamba_mixer", - "vllm.short_conv", - "vllm.linear_attention", - "vllm.plamo2_mamba_mixer", - "vllm.gdn_attention", - "vllm.sparse_attn_indexer", + "vllm::unified_attention", + "vllm::unified_attention_with_output", + "vllm::unified_mla_attention", + "vllm::unified_mla_attention_with_output", + "vllm::mamba_mixer2", + "vllm::mamba_mixer", + "vllm::short_conv", + "vllm::linear_attention", + "vllm::plamo2_mamba_mixer", + "vllm::gdn_attention", + "vllm::sparse_attn_indexer", ] def compute_hash(self) -> str: @@ -402,11 +469,12 @@ def compute_hash(self) -> str: the final hidden states. """ factors: list[Any] = [] - factors.append(self.level) + factors.append(self.mode) factors.append(self.backend) factors.append(self.custom_ops) factors.append(self.splitting_ops) factors.append(self.use_inductor) + factors.append(self.use_inductor_graph_partition) factors.append(self.inductor_compile_config) factors.append(self.inductor_passes) factors.append(self.pass_config.uuid()) @@ -452,6 +520,17 @@ def validate_cudagraph_mode_before(cls, value: Any) -> Any: return value def __post_init__(self) -> None: + if self.level is not None: + logger.warning( + "Level is deprecated and will be removed in the next release," + "either 0.12.0 or 0.11.2 whichever is soonest." + "Use mode instead." + "If both level and mode are given," + "only mode will be used." + ) + if self.mode is None: + self.mode = self.level + count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" @@ -489,6 +568,16 @@ def __post_init__(self) -> None: if isinstance(self.pass_config, dict): self.pass_config = PassConfig(**self.pass_config) + if ( + is_torch_equal_or_newer("2.9.0.dev") + and "combo_kernels" not in self.inductor_compile_config + and "benchmark_combo_kernel" not in self.inductor_compile_config + ): + # use horizontal fusion, which is useful for fusing qk-norm and + # qk-rope when query and key have different shapes. + self.inductor_compile_config["combo_kernels"] = True + self.inductor_compile_config["benchmark_combo_kernel"] = True + # migrate the deprecated flags if not self.use_cudagraph: logger.warning( @@ -539,7 +628,7 @@ def __post_init__(self) -> None: # Currently only eager and inductor backend are supported. # for piecewise compilation. Custom backends are not suppported for # piecewise compilation. Update when more backends are supported. 
- if self.level == CompilationLevel.PIECEWISE and self.backend not in [ + if self.mode == CompilationMode.VLLM_COMPILE and self.backend not in [ "", "eager", "inductor", @@ -550,16 +639,16 @@ def __post_init__(self) -> None: if self.use_inductor is not None: logger.warning_once( - "The 'use_inductor' flag is deprecated and will be\ - removed in a future release." + "The 'use_inductor' flag is deprecated and will be " + "removed in the next release (v0.12.0). " "Please use the 'backend' option instead.", ) self.backend = "inductor" if self.use_inductor else "eager" if self.backend == "": - self.backend = "inductor" + self.backend = current_platform.simple_compile_backend - def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: + def init_backend(self, vllm_config: "VllmConfig") -> str | Callable: """ Initialize the backend for the compilation config from a vllm config. Arguments: @@ -567,24 +656,27 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: Returns: The backend for the compilation config. """ - if self.level is None: + if self.mode is None: raise ValueError( - "No compilation level is set. This method should only be \ + "No compilation mode is set. This method should only be \ called via vllm config where the level is set if none is \ provided." ) - if self.level == CompilationLevel.NO_COMPILATION: - raise ValueError("No compilation level is set.") + if self.mode == CompilationMode.NONE: + raise ValueError("No compilation mode is set.") from torch._dynamo.backends.registry import list_backends torch_backends = list_backends(exclude_tags=tuple()) - if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: + if self.mode in [ + CompilationMode.STOCK_TORCH_COMPILE, + CompilationMode.DYNAMO_TRACE_ONCE, + ]: if self.backend in torch_backends: return self.backend return resolve_obj_by_qualname(self.backend) - assert self.level == CompilationLevel.PIECEWISE + assert self.mode == CompilationMode.VLLM_COMPILE if self.backend not in ["eager", "inductor"]: raise ValueError( f"Invalid backend for piecewise compilation: {self.backend}" @@ -594,25 +686,12 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: return VllmBackend(vllm_config) - def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None: - """To complete the initialization of config, - we need to know the cudagraph sizes.""" - - if self.cudagraph_capture_sizes is None: - self.cudagraph_capture_sizes = cudagraph_capture_sizes - else: - # de-duplicate the sizes provided by the config - dedup_sizes = list(set(self.cudagraph_capture_sizes)) - if len(dedup_sizes) < len(self.cudagraph_capture_sizes): - logger.info( - ( - "cudagraph sizes specified by model runner" - " %s is overridden by config %s" - ), - cudagraph_capture_sizes, - dedup_sizes, - ) - self.cudagraph_capture_sizes = dedup_sizes + def post_init_cudagraph_sizes(self) -> None: + """To complete the initialization after cudagraph related + configs are set. 
This includes: + - initialize compile_sizes + - pre-compute the mapping bs_to_padded_graph_size + """ computed_compile_sizes = [] if self.compile_sizes is not None: @@ -630,30 +709,31 @@ def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None: computed_compile_sizes.append(x) self.compile_sizes = computed_compile_sizes # type: ignore - # sort to make sure cudagraph capture sizes are in descending order - self.cudagraph_capture_sizes.sort(reverse=True) - self.max_capture_size = ( - self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0 - ) + # make sure the sizes are in ascending order + self.cudagraph_capture_sizes.sort() + if self.cudagraph_capture_sizes: + assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size # pre-compute the mapping from batch size to padded graph size - self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)] + self.bs_to_padded_graph_size = [ + 0 for i in range(self.max_cudagraph_capture_size + 1) + ] for end, start in zip( - self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0] + self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1], + [0] + self.cudagraph_capture_sizes, ): for bs in range(start, end): if bs == start: self.bs_to_padded_graph_size[bs] = start else: self.bs_to_padded_graph_size[bs] = end - self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size def set_splitting_ops_for_v1(self): - # NOTE: this function needs to be called only when level is - # CompilationLevel.PIECEWISE - assert self.level == CompilationLevel.PIECEWISE, ( + # NOTE: this function needs to be called only when mode is + # CompilationMode.VLLM_COMPILE + assert self.mode == CompilationMode.VLLM_COMPILE, ( "set_splitting_ops_for_v1 should only be called when " - "level is CompilationLevel.PIECEWISE" + "mode is CompilationMode.VLLM_COMPILE" ) if self.use_inductor_graph_partition: @@ -698,31 +778,25 @@ def set_splitting_ops_for_v1(self): def set_splitting_ops_for_inductor_graph_partition(self): assert self.use_inductor_graph_partition - use_inductor_graph_partition_msg = ( - "When use_inductor_graph_partition=True, splitting_ops " - "are ignored and set to an empty list. Instead, " - '"tags=(torch._C.Tag.cudagraph_unsafe, )," is ' - "used to annotate custom ops for graph partition." - ) - if self.splitting_ops is not None and len(self.splitting_ops) > 0: - logger.warning_once(use_inductor_graph_partition_msg) - self.splitting_ops = [] + if self.splitting_ops is None: + self.splitting_ops = list(self._attention_ops) def set_splitting_ops_for_attn_fusion(self): assert self.pass_config.enable_attn_fusion - if self.splitting_ops is None: - self.splitting_ops = [] - if self.cudagraph_mode.has_piecewise_cudagraphs(): - logger.warning_once( - "enable_attn_fusion is incompatible with piecewise " - "cudagraph when use_inductor_graph_partition is off." - "In this case, splitting_ops will be set to empty " - "list, and cudagraph_mode will be set to FULL. " - "Please ensure you are using attention backends that " - "support cudagraph or set cudagraph_mode to NONE " - "explicitly if encountering any problems." 
- ) - self.cudagraph_mode = CUDAGraphMode.FULL + # For dynamo-partition (non-inductor) attention fusion, + # set splitting_ops to empty to avoid splitting at attention ops + self.splitting_ops = [] + if self.cudagraph_mode.has_piecewise_cudagraphs(): + logger.warning_once( + "enable_attn_fusion is incompatible with piecewise " + "cudagraph when use_inductor_graph_partition is off. " + "In this case, splitting_ops will be set to empty " + "list, and cudagraph_mode will be set to FULL. " + "Please ensure you are using attention backends that " + "support cudagraph or set cudagraph_mode to NONE " + "explicitly if encountering any problems." + ) + self.cudagraph_mode = CUDAGraphMode.FULL assert not self.splitting_ops_contain_attention(), ( "attention ops should not be in splitting_ops " @@ -735,23 +809,15 @@ def splitting_ops_contain_attention(self) -> bool: ) def is_attention_compiled_piecewise(self) -> bool: - use_fx_graph_piecewise_compilation = ( - self.level == CompilationLevel.PIECEWISE - and self.splitting_ops_contain_attention() - ) + if not self.splitting_ops_contain_attention(): + return False - inductor_used = ( - self.level == CompilationLevel.PIECEWISE and self.backend == "inductor" - ) or ( - self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor" - ) - use_inductor_piecewise_compilation = ( - inductor_used - and self.use_inductor_graph_partition - and not self.splitting_ops_contain_attention() - ) + if not self.use_inductor_graph_partition: + # Dynamo-level FX split case + return self.mode == CompilationMode.VLLM_COMPILE - return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation + # Inductor partition case + return self.backend == "inductor" and self.mode > CompilationMode.NONE def custom_op_log_check(self): """ diff --git a/vllm/config/device.py b/vllm/config/device.py index 4b6642479541..e85cd15de8cf 100644 --- a/vllm/config/device.py +++ b/vllm/config/device.py @@ -3,7 +3,7 @@ import hashlib from dataclasses import field -from typing import Any, Literal, Optional, Union +from typing import Any, Literal import torch from pydantic import ConfigDict, SkipValidation @@ -19,7 +19,7 @@ class DeviceConfig: """Configuration for the device to use for vLLM execution.""" - device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto" + device: SkipValidation[Device | torch.device | None] = "auto" """Device type for vLLM execution. This parameter is deprecated and will be removed in a future release. diff --git a/vllm/config/kv_events.py b/vllm/config/kv_events.py index 1c6bdffa1281..ce46cc03c39f 100644 --- a/vllm/config/kv_events.py +++ b/vllm/config/kv_events.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Literal + +from pydantic import Field from pydantic.dataclasses import dataclass from vllm.config.utils import config @@ -18,7 +20,7 @@ class KVEventsConfig: Events can be published externally by zmq using the event publisher config. """ - publisher: str = "null" + publisher: Literal["null", "zmq"] = Field(default=None) """The publisher to use for publishing kv events. Can be "null", "zmq". """ @@ -26,7 +28,7 @@ class KVEventsConfig: """The zmq endpoint to use for publishing kv events. """ - replay_endpoint: Optional[str] = None + replay_endpoint: str | None = None """The zmq endpoint to use for replaying kv events. """ @@ -48,3 +50,7 @@ class KVEventsConfig: """The topic to use for the event publisher. 
Consumers can subscribe to this topic to receive events. """ + + def __post_init__(self): + if self.publisher is None: + self.publisher = "zmq" if self.enable_kv_cache_events else "null" diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index b33294fd66f7..dfd7ef63712a 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -4,7 +4,7 @@ import hashlib import uuid from dataclasses import field -from typing import Any, Literal, Optional, get_args +from typing import Any, Literal, get_args from pydantic.dataclasses import dataclass @@ -20,14 +20,14 @@ class KVTransferConfig: """Configuration for distributed KV cache transfer.""" - kv_connector: Optional[str] = None + kv_connector: str | None = None """The KV connector for vLLM to transmit KV caches between vLLM instances. """ - engine_id: Optional[str] = None + engine_id: str | None = None """The engine id for KV transfers.""" - kv_buffer_device: Optional[str] = "cuda" + kv_buffer_device: str = "cuda" """The device used by kv connector to buffer the KV cache. Choices are 'cuda' and 'cpu'.""" @@ -35,11 +35,11 @@ class KVTransferConfig: """The buffer size for TorchDistributedConnector. Measured in number of bytes. Recommended value: 1e9 (about 1GB).""" - kv_role: Optional[KVRole] = None + kv_role: KVRole | None = None """Whether this vLLM instance produces, consumes KV cache, or both. Choices are 'kv_producer', 'kv_consumer', and 'kv_both'.""" - kv_rank: Optional[int] = None + kv_rank: int | None = None """The rank of this vLLM instance in the KV cache transfer. Typical value: 0 for prefill instance, 1 for decode instance. Currently only 1P1D is supported.""" @@ -57,10 +57,13 @@ class KVTransferConfig: kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) """any extra config that the connector may need.""" - kv_connector_module_path: Optional[str] = None + kv_connector_module_path: str | None = None """The Python module path to dynamically load the KV connector from. Only supported in V1.""" + enable_permute_local_kv: bool = False + """Experiment feature flag to enable HND to NHD KV Transfer""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -91,7 +94,7 @@ def __post_init__(self) -> None: if self.kv_connector is not None and self.kv_role is None: raise ValueError( - "Please specify kv_disagg_role when kv_connector " + "Please specify kv_role when kv_connector " f"is set, supported roles are {get_args(KVRole)}" ) diff --git a/vllm/config/load.py b/vllm/config/load.py index 23ce29e3983d..d625c1ac987e 100644 --- a/vllm/config/load.py +++ b/vllm/config/load.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from dataclasses import field -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any +from pydantic import Field, field_validator from pydantic.dataclasses import dataclass from vllm.config.utils import config @@ -25,7 +25,7 @@ class LoadConfig: """Configuration for loading the model weights.""" - load_format: Union[str, LoadFormats] = "auto" + load_format: str | LoadFormats = "auto" """The format of the model weights to load:\n - "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.\n @@ -48,7 +48,7 @@ class LoadConfig: - "mistral" will load weights from consolidated safetensors files used by Mistral models. 
- Other custom values can be supported via plugins.""" - download_dir: Optional[str] = None + download_dir: str | None = None """Directory to download and load the weights, default to the default cache directory of Hugging Face.""" safetensors_load_strategy: str = "lazy" @@ -64,21 +64,19 @@ class LoadConfig: was quantized using torchao and saved using safetensors. Needs torchao >= 0.14.0 """ - model_loader_extra_config: Union[dict, TensorizerConfig] = field( - default_factory=dict - ) + model_loader_extra_config: dict | TensorizerConfig = Field(default_factory=dict) """Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format.""" - device: Optional[str] = None + device: str | None = None """Device to which model weights will be loaded, default to device_config.device""" - ignore_patterns: Optional[Union[list[str], str]] = None + ignore_patterns: list[str] | str = Field(default_factory=lambda: ["original/**/*"]) """The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints.""" use_tqdm_on_load: bool = True """Whether to enable tqdm for showing progress bar when loading model weights.""" - pt_load_map_location: Union[str, dict[str, str]] = "cpu" + pt_load_map_location: str | dict[str, str] = "cpu" """ pt_load_map_location: the map location for loading pytorch checkpoint, to support loading checkpoints can only be loaded on certain devices like @@ -107,12 +105,18 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - def __post_init__(self): - self.load_format = self.load_format.lower() - if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + @field_validator("load_format", mode="after") + def _lowercase_load_format(cls, load_format: str) -> str: + return load_format.lower() + + @field_validator("ignore_patterns", mode="after") + def _validate_ignore_patterns( + cls, ignore_patterns: list[str] | str + ) -> list[str] | str: + if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0: logger.info( "Ignoring the following patterns when downloading weights: %s", - self.ignore_patterns, + ignore_patterns, ) - else: - self.ignore_patterns = ["original/**/*"] + + return ignore_patterns diff --git a/vllm/config/lora.py b/vllm/config/lora.py index f97f2a111d41..2f9d638542b6 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Literal import torch -from pydantic import ConfigDict +from pydantic import ConfigDict, Field, model_validator from pydantic.dataclasses import dataclass +from typing_extensions import Self import vllm.envs as envs from vllm.config.utils import config @@ -23,6 +24,8 @@ logger = init_logger(__name__) LoRADType = Literal["auto", "float16", "bfloat16"] +MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512] +LoRAExtraVocabSize = Literal[256, 512] @config @@ -30,27 +33,34 @@ class LoRAConfig: """Configuration for LoRA.""" - max_lora_rank: int = 16 + max_lora_rank: MaxLoRARanks = 16 """Max LoRA rank.""" - max_loras: int = 1 + max_loras: int = Field(default=1, ge=1) """Max number of LoRAs in a single batch.""" fully_sharded_loras: bool = False """By default, only half of the LoRA computation is sharded with tensor 
parallelism. Enabling this will use the fully sharded layers. At high sequence length, max rank or tensor parallel size, this is likely faster. """ - max_cpu_loras: Optional[int] = None + max_cpu_loras: int | None = None """Maximum number of LoRAs to store in CPU memory. Must be >= than `max_loras`.""" - lora_dtype: Union[torch.dtype, LoRADType] = "auto" + lora_dtype: torch.dtype | LoRADType = "auto" """Data type for LoRA. If auto, will default to base model dtype.""" - lora_extra_vocab_size: int = 256 + lora_extra_vocab_size: LoRAExtraVocabSize = Field( + default=256, + deprecated=( + "`lora_extra_vocab_size` is deprecated and will be removed " + "in v0.12.0. Additional vocabulary support for " + "LoRA adapters is being phased out." + ), + ) """(Deprecated) Maximum size of extra vocabulary that can be present in a LoRA adapter. Will be removed in v0.12.0.""" lora_vocab_padding_size: ClassVar[int] = ( current_platform.get_lora_vocab_padding_size() ) - default_mm_loras: Optional[dict[str, str]] = None + default_mm_loras: dict[str, str] | None = None """Dictionary mapping specific modalities to LoRA model paths; this field is only applicable to multimodal models and should be leveraged when a model always expects a LoRA to be active when a given modality is present. @@ -60,9 +70,6 @@ class LoRAConfig: per prompt. When run in offline mode, the lora IDs for n modalities will be automatically assigned to 1-n with the names of the modalities in alphabetic order.""" - bias_enabled: bool = False - """[DEPRECATED] Enable bias for LoRA adapters. This option will be - removed in v0.12.0.""" def compute_hash(self) -> str: """ @@ -83,40 +90,12 @@ def compute_hash(self) -> str: factors.append(self.lora_dtype) factors.append(self.lora_extra_vocab_size) factors.append(self.lora_vocab_padding_size) - factors.append(self.bias_enabled) + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - def __post_init__(self): - # Deprecation warning for lora_extra_vocab_size - logger.warning( - "`lora_extra_vocab_size` is deprecated and will be removed " - "in v0.12.0. Additional vocabulary support for " - "LoRA adapters is being phased out." - ) - - # Deprecation warning for enable_lora_bias - if self.bias_enabled: - logger.warning( - "`enable_lora_bias` is deprecated and will be removed in v0.12.0." - ) - - # Setting the maximum rank to 512 should be able to satisfy the vast - # majority of applications. - possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) - possible_lora_extra_vocab_size = (256, 512) - if self.max_lora_rank not in possible_max_ranks: - raise ValueError( - f"max_lora_rank ({self.max_lora_rank}) must be one of " - f"{possible_max_ranks}." - ) - if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: - raise ValueError( - f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}." 
- ) - if self.max_loras < 1: - raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + @model_validator(mode="after") + def _validate_lora_config(self) -> Self: if self.max_cpu_loras is None: self.max_cpu_loras = self.max_loras elif self.max_cpu_loras < self.max_loras: @@ -125,6 +104,8 @@ def __post_init__(self): f"max_loras ({self.max_loras})" ) + return self + def verify_with_cache_config(self, cache_config: CacheConfig): if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: raise ValueError("V0 LoRA does not support CPU offload, please use V1.") diff --git a/vllm/config/model.py b/vllm/config/model.py index d0c027e47675..f81d324d8f80 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -4,18 +4,10 @@ import hashlib import json import warnings +from collections.abc import Callable from dataclasses import InitVar, field from importlib.util import find_spec -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, - Optional, - Union, - cast, - get_args, -) +from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch from pydantic import ConfigDict, SkipValidation, field_validator, model_validator @@ -28,6 +20,9 @@ from vllm.config.scheduler import RunnerType from vllm.config.utils import assert_hashable, config, getattr_iter from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.platforms import current_platform from vllm.transformers_utils.config import ( ConfigFormat, @@ -38,6 +33,7 @@ get_sentence_transformer_tokenizer_config, is_encoder_decoder, is_interleaved, + try_get_dense_modules, try_get_generation_config, try_get_safetensors_metadata, try_get_tokenizer_config, @@ -45,13 +41,16 @@ ) from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype +from vllm.utils import LayerBlockType +from vllm.utils.import_utils import LazyLoader +from vllm.utils.torch_utils import common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig import vllm.model_executor.layers.quantization as me_quant import vllm.model_executor.models as me_models + from vllm.attention.backends.registry import _Backend from vllm.config.load import LoadConfig from vllm.config.parallel import ParallelConfig from vllm.model_executor.layers.quantization import QuantizationMethods @@ -59,6 +58,7 @@ else: PretrainedConfig = Any + _Backend = Any me_quant = LazyLoader( "model_executor", globals(), "vllm.model_executor.layers.quantization" ) @@ -89,7 +89,7 @@ LogprobsMode = Literal[ "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" ] -HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig], PretrainedConfig]] +HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig] ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"] _RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = { @@ -121,7 +121,7 @@ class ModelConfig: """Convert the model using adapters defined in [vllm.model_executor.models.adapters][]. The most common use case is to adapt a text generation model to be used for pooling tasks.""" - task: Optional[TaskOption] = None + task: TaskOption | None = None """[DEPRECATED] The task to use the model for. If the model supports more than one model runner, this is used to select which model runner to run. 
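The LoRA hunks above replace hand-written range checks with Literal-typed fields plus a model_validator. A minimal sketch of that pattern with hypothetical Demo names:

from typing import Literal

from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

DemoMaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]


@dataclass
class DemoLoRAConfig:
    max_lora_rank: DemoMaxLoRARanks = 16
    max_loras: int = Field(default=1, ge=1)
    max_cpu_loras: int | None = None

    @model_validator(mode="after")
    def _check(self) -> Self:
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras   # default: mirror max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError("max_cpu_loras must be >= max_loras")
        return self


DemoLoRAConfig(max_lora_rank=64)      # ok
# DemoLoRAConfig(max_lora_rank=48)    # rejected: 48 is not one of the allowed ranks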
@@ -139,7 +139,7 @@ class ModelConfig: trust_remote_code: bool = False """Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.""" - dtype: Union[ModelDType, torch.dtype] = "auto" + dtype: ModelDType | torch.dtype = "auto" """Data type for model weights and activations:\n - "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.\n @@ -148,33 +148,37 @@ class ModelConfig: - "bfloat16" for a balance between precision and range.\n - "float" is shorthand for FP32 precision.\n - "float32" for FP32 precision.""" - seed: Optional[int] = None + seed: int | None = None """Random seed for reproducibility. Initialized to None in V0, but initialized to 0 in V1.""" - hf_config_path: Optional[str] = None + hf_config: PretrainedConfig = field(init=False) + """The Hugging Face config of the model.""" + hf_text_config: PretrainedConfig = field(init=False) + """The Hugging Face config of the text model (same as hf_config for text models).""" + hf_config_path: str | None = None """Name or path of the Hugging Face config to use. If unspecified, model name or path will be used.""" allowed_local_media_path: str = "" """Allowing API requests to read local images or videos from directories specified by the server file system. This is a security risk. Should only be enabled in trusted environments.""" - allowed_media_domains: Optional[list[str]] = None + allowed_media_domains: list[str] | None = None """If set, only media URLs that belong to this domain can be used for multi-modal inputs. """ - revision: Optional[str] = None + revision: str | None = None """The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" - code_revision: Optional[str] = None + code_revision: str | None = None """The specific revision to use for the model code on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" rope_scaling: dict[str, Any] = field(default_factory=dict) """RoPE scaling configuration. For example, `{"rope_type":"dynamic","factor":2.0}`.""" - rope_theta: Optional[float] = None + rope_theta: float | None = None """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE theta improves the performance of the scaled model.""" - tokenizer_revision: Optional[str] = None + tokenizer_revision: str | None = None """The specific revision to use for the tokenizer on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" @@ -187,9 +191,9 @@ class ModelConfig: - 1k -> 1000\n - 1K -> 1024\n - 25.6k -> 25,600""" - spec_target_max_model_len: Optional[int] = None + spec_target_max_model_len: int | None = None """Specify the maximum length for spec decoding draft models.""" - quantization: SkipValidation[Optional[QuantizationMethods]] = None + quantization: SkipValidation[QuantizationMethods | None] = None """Method used to quantize the weights. If `None`, we first check the `quantization_config` attribute in the model config file. If that is `None`, we assume the model weights are not quantized and use `dtype` to @@ -228,9 +232,11 @@ class ModelConfig: output will contain token ids.""" enable_prompt_embeds: bool = False """If `True`, enables passing text embeddings as inputs via the - `prompt_embeds` key. 
Note that enabling this will double the time required - for graph compilation.""" - served_model_name: Optional[Union[str, list[str]]] = None + `prompt_embeds` key. + + WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed. + Only enable this flag for trusted users!""" + served_model_name: str | list[str] | None = None """The model name(s) used in the API. If multiple names are provided, the server will respond to any of the provided names. The model name in the model field of a response will be the first name in this list. If not @@ -238,20 +244,20 @@ class ModelConfig: that this name(s) will also be used in `model_name` tag content of prometheus metrics, if multiple names provided, metrics tag will take the first one.""" - config_format: Union[str, ConfigFormat] = "auto" + config_format: str | ConfigFormat = "auto" """The format of the model config to load:\n - "auto" will try to load the config in hf format if available else it will try to load in mistral format.\n - "hf" will load the config in hf format.\n - "mistral" will load the config in mistral format.""" - hf_token: Optional[Union[bool, str]] = None + hf_token: bool | str | None = None """The token to use as HTTP bearer authorization for remote files . If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).""" hf_overrides: HfOverrides = field(default_factory=dict) """If a dictionary, contains arguments to be forwarded to the Hugging Face config. If a callable, it is called to update the HuggingFace config.""" - logits_processor_pattern: Optional[str] = None + logits_processor_pattern: str | None = None """Optional regex pattern specifying valid logits processor qualified names that can be passed with the `logits_processors` extra completion argument. Defaults to `None`, which allows no processors.""" @@ -269,7 +275,7 @@ class ModelConfig: `--generation-config vllm`, only the override parameters are used.""" enable_sleep_mode: bool = False """Enable sleep mode for the engine (only cuda platform is supported).""" - model_impl: Union[str, ModelImpl] = "auto" + model_impl: str | ModelImpl = "auto" """Which implementation of the model to use:\n - "auto" will try to use the vLLM implementation, if it exists, and fall back to the Transformers implementation if no vLLM implementation is @@ -278,36 +284,38 @@ class ModelConfig: - "transformers" will use the Transformers model implementation.\n - "terratorch" will use the TerraTorch model implementation. """ - override_attention_dtype: Optional[str] = None + override_attention_dtype: str | None = None """Override dtype for attention""" - logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None + logits_processors: list[str | type[LogitsProcessor]] | None = None """One or more logits processors' fully-qualified class names or class definitions""" - io_processor_plugin: Optional[str] = None + io_processor_plugin: str | None = None """IOProcessor plugin name to load at model startup""" # Pooler config - pooler_config: Optional[PoolerConfig] = None + pooler_config: PoolerConfig | None = None """Pooler config which controls the behaviour of output pooling in pooling models.""" - override_pooler_config: Optional[Union[dict, PoolerConfig]] = None + override_pooler_config: dict | PoolerConfig | None = None """[DEPRECATED] Use `pooler_config` instead. 
This field will be removed in v0.12.0 or v1.0.0, whichever is sooner.""" # Multimodal config and init vars - multimodal_config: Optional[MultiModalConfig] = None + multimodal_config: MultiModalConfig | None = None """Configuration for multimodal model. If `None`, this will be inferred from the architecture of `self.model`.""" - limit_mm_per_prompt: InitVar[Optional[dict[str, Union[int, dict[str, int]]]]] = None - media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None - mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None - mm_processor_cache_gb: InitVar[Optional[float]] = None - mm_processor_cache_type: InitVar[Optional[MMCacheType]] = None - mm_shm_cache_max_object_size_mb: InitVar[Optional[int]] = None - mm_encoder_tp_mode: InitVar[Optional[MMEncoderTPMode]] = None - interleave_mm_strings: InitVar[Optional[bool]] = None - skip_mm_profiling: InitVar[Optional[bool]] = None - video_pruning_rate: InitVar[Optional[float]] = None + limit_mm_per_prompt: InitVar[dict[str, int | dict[str, int]] | None] = None + enable_mm_embeds: InitVar[bool | None] = None + media_io_kwargs: InitVar[dict[str, dict[str, Any]] | None] = None + mm_processor_kwargs: InitVar[dict[str, Any] | None] = None + mm_processor_cache_gb: InitVar[float | None] = None + mm_processor_cache_type: InitVar[MMCacheType | None] = None + mm_shm_cache_max_object_size_mb: InitVar[int | None] = None + mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None + mm_encoder_attn_backend: InitVar[_Backend | str | None] = None + interleave_mm_strings: InitVar[bool | None] = None + skip_mm_profiling: InitVar[bool | None] = None + video_pruning_rate: InitVar[float | None] = None def compute_hash(self) -> str: """ @@ -369,7 +377,7 @@ def compute_hash(self) -> str: def _update_nested( self, - target: Union["PretrainedConfig", dict[str, Any]], + target: PretrainedConfig | dict[str, Any], updates: dict[str, Any], ) -> None: """Recursively updates a config or dict with nested updates.""" @@ -397,7 +405,7 @@ def _update_nested( def _apply_dict_overrides( self, - config: "PretrainedConfig", + config: PretrainedConfig, overrides: dict[str, Any], ) -> None: """Apply dict overrides, handling both nested configs and dict values.""" @@ -415,17 +423,23 @@ def _apply_dict_overrides( def __post_init__( self, # Multimodal config init vars - limit_mm_per_prompt: Optional[dict[str, int]], - media_io_kwargs: Optional[dict[str, dict[str, Any]]], - mm_processor_kwargs: Optional[dict[str, Any]], - mm_processor_cache_gb: Optional[float], - mm_processor_cache_type: Optional[MMCacheType], - mm_shm_cache_max_object_size_mb: Optional[int], - mm_encoder_tp_mode: Optional[MMEncoderTPMode], - interleave_mm_strings: Optional[bool], - skip_mm_profiling: Optional[bool], - video_pruning_rate: Optional[float], + limit_mm_per_prompt: dict[str, int] | None, + enable_mm_embeds: bool | None, + media_io_kwargs: dict[str, dict[str, Any]] | None, + mm_processor_kwargs: dict[str, Any] | None, + mm_processor_cache_gb: float | None, + mm_processor_cache_type: MMCacheType | None, + mm_shm_cache_max_object_size_mb: int | None, + mm_encoder_tp_mode: MMEncoderTPMode | None, + mm_encoder_attn_backend: _Backend | str | None, + interleave_mm_strings: bool | None, + skip_mm_profiling: bool | None, + video_pruning_rate: float | None, ) -> None: + # Enable batch invariance settings if requested + if vllm_is_batch_invariant(): + self.enforce_eager = True + # Set the default seed to 0 in V1. 
# NOTE(woosuk): In V0, we set the default seed to None because the # driver worker shares the same process as the user process, and thus @@ -721,12 +735,14 @@ def _task_to_convert(task: TaskOption) -> ConvertType: mm_config_kwargs = dict( limit_per_prompt=limit_mm_per_prompt, + enable_mm_embeds=enable_mm_embeds, media_io_kwargs=media_io_kwargs, mm_processor_kwargs=mm_processor_kwargs, mm_processor_cache_gb=mm_processor_cache_gb, mm_processor_cache_type=mm_processor_cache_type, mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=mm_encoder_tp_mode, + mm_encoder_attn_backend=mm_encoder_attn_backend, interleave_mm_strings=interleave_mm_strings, skip_mm_profiling=skip_mm_profiling, video_pruning_rate=video_pruning_rate, @@ -771,35 +787,29 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": def _get_transformers_backend_cls(self) -> str: """Determine which Transformers backend class will be used if `model_impl` is set to `transformers` or `auto`.""" - prefix = "Transformers" - prefix += "MoE" if self.get_num_experts() > 1 else "" + cls = "Transformers" + # If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal + cls += "MultiModal" if self.hf_config != self.hf_text_config else "" + cls += "MoE" if self.get_num_experts() > 1 else "" # Check if the architecture we're wrapping has defaults runner = None - convert = None + task = None if defaults := try_match_architecture_defaults(self.architectures[0]): - _, (runner, convert) = defaults - # Overwrite with user-specified values + _, (runner, task) = defaults + # User specified value take precedence if self.runner != "auto": runner = self.runner - if self.convert not in {"auto", "none"}: - convert = self.convert - # Fall back to default values if still not set - if runner is None: - runner = "generate" - if convert in {None, "none"}: - convert = "embed" - # Resolve Transformers backend pooling classes - if runner == "pooling": - if convert == "embed": - return prefix + "EmbeddingModel" - if convert == "classify": - return prefix + "ForSequenceClassification" - # Resolve Transformers backend generate classes - if self.hf_config != self.hf_text_config: - # If 'hf_text_config' is the same as 'hf_config'. If not, it is - # probably a composite config, i.e. multimodal - return prefix + "ForMultimodalLM" - return prefix + "ForCausalLM" + # Only consider Transformers backend pooling classes if we're wrapping an + # architecture that defaults to pooling. Otherwise, we return the LM class + # and use adapters. + if runner == "pooling" and task in {"embed", "classify"}: + if task == "embed": + cls += "EmbeddingModel" + elif task == "classify": + cls += "ForSequenceClassification" + else: + cls += "ForCausalLM" + return cls def using_transformers_backend(self) -> bool: """Check if the model is using the Transformers backend class.""" @@ -1209,7 +1219,24 @@ def verify_with_parallel_config( "Supported models implement the `SupportsPP` interface." 
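Standalone sketch of the class-name assembly done by the rewritten _get_transformers_backend_cls above (hypothetical helper; the real method derives multimodal/num_experts/runner/task from the HF config and architecture defaults):

def transformers_backend_cls(
    *, multimodal: bool, num_experts: int, runner: str | None, task: str | None
) -> str:
    cls = "Transformers"
    cls += "MultiModal" if multimodal else ""
    cls += "MoE" if num_experts > 1 else ""
    # Pooling classes only when the wrapped architecture defaults to pooling.
    if runner == "pooling" and task in {"embed", "classify"}:
        cls += "EmbeddingModel" if task == "embed" else "ForSequenceClassification"
    else:
        cls += "ForCausalLM"
    return cls


name = transformers_backend_cls(multimodal=True, num_experts=1, runner=None, task=None)
assert name == "TransformersMultiModalForCausalLM"
name = transformers_backend_cls(multimodal=False, num_experts=8, runner="pooling", task="embed")
assert name == "TransformersMoEEmbeddingModel"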
) - def get_sliding_window(self) -> Optional[int]: + decode_context_parallel_size = parallel_config.decode_context_parallel_size + if decode_context_parallel_size > 1 and not self.use_mla: + total_num_kv_heads = self.get_total_num_kv_heads() + assert tensor_parallel_size > total_num_kv_heads, ( + f"tensor parallel size {tensor_parallel_size} must be greater " + f"than total num kv heads {total_num_kv_heads} when enable " + f"decode context parallel for GQA/MQA" + ) + + max_dcp_size = tensor_parallel_size // total_num_kv_heads + assert decode_context_parallel_size <= max_dcp_size, ( + f"decode context parallel size must less than or equal to " + f"(tensor parallel size {tensor_parallel_size} // total " + f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, " + f"but got {decode_context_parallel_size}" + ) + + def get_sliding_window(self) -> int | None: """Get the sliding window size from the HF text config if present.""" return getattr(self.hf_text_config, "sliding_window", None) @@ -1479,7 +1506,7 @@ def get_num_layers_by_block_type( f"{block_type.value} layers" ) - def get_mamba_chunk_size(self) -> Optional[int]: + def get_mamba_chunk_size(self) -> int | None: """ Returns the mamba chunk size if it exists """ @@ -1629,10 +1656,6 @@ def has_noops(self) -> bool: def has_inner_state(self): return self._model_info.has_inner_state - @property - def is_v1_compatible(self) -> bool: - return not self._model_info.supports_v0_only - @property def use_mla(self) -> bool: return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE @@ -1689,6 +1712,20 @@ def head_dtype(self) -> torch.dtype: logger.debug_once("head dtype: %s", head_dtype) return head_dtype + @property + def hidden_size(self): + if hasattr(self.hf_config, "hidden_size"): + return self.hf_config.hidden_size + text_config = self.hf_config.get_text_config() + return text_config.hidden_size + + @property + def embedding_size(self): + dense_modules = try_get_dense_modules(self.model, revision=self.revision) + if dense_modules is not None: + return dense_modules[-1]["out_features"] + return self.hidden_size + def get_and_verify_max_len(self, max_model_len: int): # Consider max_model_len in tokenizer_config only when # pooling models use absolute position_embedding. @@ -1715,9 +1752,7 @@ def get_and_verify_max_len(self, max_model_len: int): return max_model_len -def get_served_model_name( - model: str, served_model_name: Optional[Union[str, list[str]]] -): +def get_served_model_name(model: str, served_model_name: str | list[str] | None): """ If the input is a non-empty list, the first model_name in `served_model_name` is taken. @@ -1761,9 +1796,9 @@ def iter_architecture_defaults(): def try_match_architecture_defaults( architecture: str, *, - runner_type: Optional[RunnerType] = None, - convert_type: Optional[ConvertType] = None, -) -> Optional[tuple[str, tuple[RunnerType, ConvertType]]]: + runner_type: RunnerType | None = None, + convert_type: ConvertType | None = None, +) -> tuple[str, tuple[RunnerType, ConvertType]] | None: for suffix, ( default_runner_type, default_convert_type, @@ -1817,20 +1852,20 @@ def _find_dtype( model_id: str, config: PretrainedConfig, *, - revision: Optional[str], + revision: str | None, ): - # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct - # because config.torch_dtype can be None. - config_dtype = getattr(config, "torch_dtype", None) + # NOTE: getattr(config, "dtype", torch.float32) is not correct + # because config.dtype can be None. 
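Worked example of the decode-context-parallel constraint asserted above, written as a standalone check (hypothetical helper; GQA/MQA, non-MLA case):

def check_dcp(tensor_parallel_size: int, total_num_kv_heads: int, dcp_size: int) -> None:
    if dcp_size <= 1:
        return
    assert tensor_parallel_size > total_num_kv_heads, "TP must exceed total KV heads for DCP"
    max_dcp_size = tensor_parallel_size // total_num_kv_heads
    assert dcp_size <= max_dcp_size, f"decode context parallel size must be <= {max_dcp_size}"


check_dcp(tensor_parallel_size=8, total_num_kv_heads=2, dcp_size=4)    # ok: 8 // 2 == 4
# check_dcp(tensor_parallel_size=8, total_num_kv_heads=2, dcp_size=8)  # fails: 8 > 4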
+ config_dtype = getattr(config, "dtype", None) # Fallbacks for multi-modal models if the root config - # does not define torch_dtype + # does not define dtype if config_dtype is None: - config_dtype = getattr(config.get_text_config(), "torch_dtype", None) + config_dtype = getattr(config.get_text_config(), "dtype", None) if config_dtype is None and hasattr(config, "vision_config"): - config_dtype = getattr(config.vision_config, "torch_dtype", None) + config_dtype = getattr(config.vision_config, "dtype", None) if config_dtype is None and hasattr(config, "encoder_config"): - config_dtype = getattr(config.encoder_config, "torch_dtype", None) + config_dtype = getattr(config.encoder_config, "dtype", None) # Try to read the dtype of the weights if they are in safetensors format if config_dtype is None: @@ -1902,10 +1937,10 @@ def _resolve_auto_dtype( def _get_and_verify_dtype( model_id: str, config: PretrainedConfig, - dtype: Union[str, torch.dtype], + dtype: str | torch.dtype, *, is_pooling_model: bool, - revision: Optional[str] = None, + revision: str | None = None, ) -> torch.dtype: config_dtype = _find_dtype(model_id, config, revision=revision) model_type = config.model_type @@ -1947,7 +1982,7 @@ def _get_and_verify_dtype( def _get_head_dtype( config: PretrainedConfig, dtype: torch.dtype, runner_type: str ) -> torch.dtype: - head_dtype: Optional[Union[str, torch.dtype]] = getattr(config, "head_dtype", None) + head_dtype: str | torch.dtype | None = getattr(config, "head_dtype", None) if head_dtype == "model": return dtype @@ -1970,12 +2005,12 @@ def _get_head_dtype( def _get_and_verify_max_len( hf_config: PretrainedConfig, - tokenizer_config: Optional[dict], - max_model_len: Optional[int], + tokenizer_config: dict | None, + max_model_len: int | None, disable_sliding_window: bool, - sliding_window: Optional[int], - spec_target_max_model_len: Optional[int] = None, - encoder_config: Optional[Any] = None, + sliding_window: int | None, + spec_target_max_model_len: int | None = None, + encoder_config: Any | None = None, ) -> int: """Get and verify the model's maximum length.""" derived_max_model_len = float("inf") diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index fc8d2262dcb4..ef73720efe09 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -3,14 +3,18 @@ import hashlib from collections.abc import Mapping -from dataclasses import field -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, TypeAlias -from pydantic import ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic.dataclasses import dataclass from vllm.config.utils import config +if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend +else: + _Backend = Any + @dataclass class BaseDummyOptions: @@ -23,31 +27,31 @@ class BaseDummyOptions: class VideoDummyOptions(BaseDummyOptions): """Options for generating dummy video data during profiling.""" - num_frames: Optional[int] = Field(None, gt=0) - width: Optional[int] = Field(None, gt=0) - height: Optional[int] = Field(None, gt=0) + num_frames: int | None = Field(None, gt=0) + width: int | None = Field(None, gt=0) + height: int | None = Field(None, gt=0) @dataclass(config=ConfigDict(extra="forbid")) class ImageDummyOptions(BaseDummyOptions): """Options for generating dummy image data during profiling.""" - width: Optional[int] = Field(None, gt=0) - height: Optional[int] = Field(None, gt=0) + width: int | None = 
Field(None, gt=0) + height: int | None = Field(None, gt=0) @dataclass(config=ConfigDict(extra="forbid")) class AudioDummyOptions(BaseDummyOptions): """Options for generating dummy audio data during profiling.""" - length: Optional[int] = Field(None, gt=0) + length: int | None = Field(None, gt=0) MMEncoderTPMode = Literal["weights", "data"] MMCacheType = Literal["shm", "lru"] -DummyOptions = Union[ - BaseDummyOptions, VideoDummyOptions, ImageDummyOptions, AudioDummyOptions -] +DummyOptions: TypeAlias = ( + BaseDummyOptions | VideoDummyOptions | ImageDummyOptions | AudioDummyOptions +) @config @@ -55,7 +59,7 @@ class AudioDummyOptions(BaseDummyOptions): class MultiModalConfig: """Controls the behavior of multimodal models.""" - limit_per_prompt: dict[str, DummyOptions] = field(default_factory=dict) + limit_per_prompt: dict[str, DummyOptions] = Field(default_factory=dict) """The maximum number of input items and options allowed per prompt for each modality. Defaults to 999 for each modality. @@ -71,11 +75,19 @@ class MultiModalConfig: {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512, "height": 512}} """ - media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + enable_mm_embeds: bool = False + """If `True`, enables passing multimodal embeddings: + for `LLM` class, this refers to tensor inputs under `multi_modal_data`; + for the OpenAI-compatible server, this refers to chat messages with content + `"type": "*_embeds"`. + + WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed. + Only enable this flag for trusted users!""" + media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict) """Additional args passed to process media inputs, keyed by modalities. For example, to set num_frames for video, set `--media-io-kwargs '{"video": {"num_frames": 40} }'`""" - mm_processor_kwargs: Optional[dict[str, object]] = None + mm_processor_kwargs: dict[str, object] | None = None """Arguments to be forwarded to the model's processor for multi-modal data, e.g., image processor. Overrides for the multi-modal processor obtained from `transformers.AutoProcessor.from_pretrained`. @@ -84,7 +96,7 @@ class MultiModalConfig: For example, for Phi-3-Vision: `{"num_crops": 4}`.""" - mm_processor_cache_gb: float = 4 + mm_processor_cache_gb: float = Field(default=4, ge=0) """The size (in GiB) of the multi-modal processor cache, which is used to avoid re-processing past multi-modal inputs. @@ -96,7 +108,7 @@ class MultiModalConfig: mm_processor_cache_type: MMCacheType = "lru" """Type of cache to use for the multi-modal preprocessor/mapper. If `shm`, use shared memory FIFO cache. If `lru`, use mirrored LRU cache.""" - mm_shm_cache_max_object_size_mb: int = 128 + mm_shm_cache_max_object_size_mb: int = Field(default=128, ge=0) """Size limit (in MiB) for each object stored in the multi-modal processor shared memory cache. Only effective when `mm_processor_cache_type` is `"shm"`.""" @@ -113,6 +125,10 @@ class MultiModalConfig: DP (which is controlled by `--data-parallel-size`). This is only supported on a per-model basis and falls back to `"weights"` if the encoder does not support DP.""" + mm_encoder_attn_backend: _Backend | None = None + """Optional override for the multi-modal encoder attention backend when + using vision transformers. Accepts any value from + `vllm.attention.backends.registry._Backend` (e.g. 
`FLASH_ATTN`).""" interleave_mm_strings: bool = False """Enable fully interleaved support for multimodal prompts, while using --chat-template-content-format=string.""" @@ -123,7 +139,7 @@ class MultiModalConfig: This reduces engine startup time but shifts the responsibility to users for estimating the peak memory usage of the activation of multimodal encoder and embedding cache.""" - video_pruning_rate: Optional[float] = None + video_pruning_rate: float | None = Field(default=None, ge=0.0, lt=1.0) """Sets pruning rate for video pruning via Efficient Video Sampling. Value sits in range [0;1) and determines fraction of media tokens from each video to be pruned. @@ -132,7 +148,7 @@ class MultiModalConfig: @field_validator("limit_per_prompt", mode="before") @classmethod def _validate_limit_per_prompt( - cls, value: dict[str, Union[int, dict[str, int]]] + cls, value: dict[str, int | dict[str, int]] ) -> dict[str, DummyOptions]: for k, v in value.items(): # Handle legacy format where only count is specified @@ -149,6 +165,41 @@ def _validate_limit_per_prompt( value[k] = BaseDummyOptions(**v) return value + @field_validator("mm_encoder_attn_backend", mode="before") + @classmethod + def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: + from vllm.attention.backends.registry import ( + _Backend as BackendEnum, + ) + from vllm.attention.backends.registry import ( + backend_name_to_enum, + ) + + if value is None or isinstance(value, BackendEnum): + return value + + if isinstance(value, str): + candidate = backend_name_to_enum(value.upper()) + if candidate is not None: + return candidate + + valid_backends = ", ".join(sorted(BackendEnum.__members__.keys())) + raise ValueError( + f"Invalid mm encoder attention backend. Expected one of: {valid_backends}." + ) + + @model_validator(mode="after") + def _validate_multimodal_config(self): + if self.mm_processor_cache_type != "shm" and ( + self.mm_shm_cache_max_object_size_mb + != MultiModalConfig.mm_shm_cache_max_object_size_mb + ): + raise ValueError( + "'mm_shm_cache_max_object_size_mb' should only be set when " + "'mm_processor_cache_type' is 'shm'." + ) + return self + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -161,9 +212,11 @@ def compute_hash(self) -> str: excluding anything before input ids/embeddings and after the final hidden states. """ - # no factors to consider. - # this config will not affect the computation graph. - factors: list[Any] = [] + factors: list[Any] = [ + self.mm_encoder_attn_backend.name + if self.mm_encoder_attn_backend is not None + else None + ] hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @@ -179,7 +232,7 @@ def get_limit_per_prompt(self, modality: str) -> int: return 999 return limit_data.count - def get_dummy_options(self, modality: str) -> Optional[BaseDummyOptions]: + def get_dummy_options(self, modality: str) -> BaseDummyOptions | None: """ Get the configurable dummy data options for a modality. Returns None if no options are configured for this modality. 
diff --git a/vllm/config/observability.py b/vllm/config/observability.py index 6c7b5fbbee47..564c4f7aed41 100644 --- a/vllm/config/observability.py +++ b/vllm/config/observability.py @@ -3,8 +3,10 @@ import hashlib from functools import cached_property -from typing import Any, Literal, Optional, cast +from typing import Any, Literal, cast +from packaging.version import parse +from pydantic import field_validator, model_validator from pydantic.dataclasses import dataclass from vllm import version @@ -18,7 +20,7 @@ class ObservabilityConfig: """Configuration for observability - metrics and tracing.""" - show_hidden_metrics_for_version: Optional[str] = None + show_hidden_metrics_for_version: str | None = None """Enable deprecated Prometheus metrics that have been hidden since the specified version. For example, if a previously deprecated metric has been hidden since the v0.7.0 release, you use @@ -33,10 +35,10 @@ def show_hidden_metrics(self) -> bool: return False return version._prev_minor_version_was(self.show_hidden_metrics_for_version) - otlp_traces_endpoint: Optional[str] = None + otlp_traces_endpoint: str | None = None """Target URL to which OpenTelemetry traces will be sent.""" - collect_detailed_traces: Optional[list[DetailedTraceModules]] = None + collect_detailed_traces: list[DetailedTraceModules] | None = None """It makes sense to set this only if `--otlp-traces-endpoint` is set. If set, it will collect detailed traces for the specified modules. This involves use of possibly costly and or blocking operations and hence might @@ -79,25 +81,43 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str - def __post_init__(self): - if ( - self.collect_detailed_traces is not None - and len(self.collect_detailed_traces) == 1 - and "," in self.collect_detailed_traces[0] - ): - self._parse_collect_detailed_traces() - - from vllm.tracing import is_otel_available, otel_import_error_traceback - - if not is_otel_available() and self.otlp_traces_endpoint is not None: + @field_validator("show_hidden_metrics_for_version") + @classmethod + def _validate_show_hidden_metrics_for_version(cls, value: str | None) -> str | None: + if value is not None: + # Raises an exception if the string is not a valid version. + parse(value) + return value + + @field_validator("otlp_traces_endpoint") + @classmethod + def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None: + if value is not None: + from vllm.tracing import is_otel_available, otel_import_error_traceback + + if not is_otel_available(): + raise ValueError( + "OpenTelemetry is not available. Unable to configure " + "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " + f"installed. Original error:\n{otel_import_error_traceback}" + ) + return value + + @field_validator("collect_detailed_traces") + @classmethod + def _validate_collect_detailed_traces( + cls, value: list[DetailedTraceModules] | None + ) -> list[DetailedTraceModules] | None: + """Handle the legacy case where users might provide a comma-separated + string instead of a list of strings.""" + if value is not None and len(value) == 1 and "," in value[0]: + value = cast(list[DetailedTraceModules], value[0].split(",")) + return value + + @model_validator(mode="after") + def _validate_tracing_config(self): + if self.collect_detailed_traces and not self.otlp_traces_endpoint: raise ValueError( - "OpenTelemetry is not available. Unable to configure " - "'otlp_traces_endpoint'. 
Ensure OpenTelemetry packages are " - f"installed. Original error:\n{otel_import_error_traceback}" + "collect_detailed_traces requires `--otlp-traces-endpoint` to be set." ) - - def _parse_collect_detailed_traces(self): - assert isinstance(self.collect_detailed_traces, list) - self.collect_detailed_traces = cast( - list[DetailedTraceModules], self.collect_detailed_traces[0].split(",") - ) + return self diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index c5fe3b97f8a3..e8847354bb09 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -3,11 +3,10 @@ import hashlib import os -from dataclasses import field -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal import torch -from pydantic import model_validator +from pydantic import Field, model_validator from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self @@ -15,23 +14,28 @@ import vllm.envs as envs from vllm.config.utils import config from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, get_open_ports_list +from vllm.utils.network_utils import get_open_ports_list +from vllm.utils.torch_utils import cuda_device_count_stateless if TYPE_CHECKING: from ray.runtime_env import RuntimeEnv from ray.util.placement_group import PlacementGroup - from vllm.executor.executor_base import ExecutorBase + from vllm.v1.executor import Executor else: RuntimeEnv = Any PlacementGroup = Any - ExecutorBase = Any + Executor = Any logger = init_logger(__name__) ExpertPlacementStrategy = Literal["linear", "round_robin"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] +DataParallelBackend = Literal["ray", "mp"] @config @@ -49,7 +53,7 @@ class EPLBConfig: of the last `lb_window_size` steps will be used for rearranging experts. """ - num_redundant_experts: int = 0 + num_redundant_experts: int = Field(default=0, ge=0) """Number of redundant experts to use for expert parallelism.""" log_balancedness: bool = False @@ -75,7 +79,7 @@ class ParallelConfig: """Number of local data parallel groups.""" data_parallel_rank: int = 0 """Rank of the data parallel group.""" - data_parallel_rank_local: Optional[int] = None + data_parallel_rank_local: int | None = None """Local rank of the data parallel group, set only in SPMD mode.""" data_parallel_master_ip: str = "127.0.0.1" @@ -84,7 +88,7 @@ class ParallelConfig: """Port for data parallel messaging.""" data_parallel_master_port: int = 29500 """Port of the data parallel master.""" - data_parallel_backend: str = "mp" + data_parallel_backend: DataParallelBackend = "mp" """Backend to use for data parallel, either "mp" or "ray".""" data_parallel_external_lb: bool = False """Whether to use "external" DP LB mode. 
Applies only to online serving @@ -102,7 +106,7 @@ class ParallelConfig: """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False """Enable expert parallelism load balancing for MoE layers.""" - eplb_config: EPLBConfig = field(default_factory=EPLBConfig) + eplb_config: EPLBConfig = Field(default_factory=EPLBConfig) """Expert parallelism configuration.""" expert_placement_strategy: ExpertPlacementStrategy = "linear" """The expert placement strategy for MoE layers:\n @@ -113,24 +117,43 @@ class ParallelConfig: with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1 will have experts [1, 3]. This strategy can help improve load balancing for grouped expert models with no redundant experts.""" - num_redundant_experts: Optional[int] = None + all2all_backend: ( + Literal[ + "naive", + "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter", + "flashinfer_all2allv", + ] + | None + ) = None + """All2All backend for MoE expert parallel communication. If not set, uses + the value from VLLM_ALL2ALL_BACKEND environment variable. Available options: + - "naive": Naive all2all implementation using broadcasts + - "allgather_reducescatter": All2all based on allgather and reducescatter + - "pplx": Use pplx kernels + - "deepep_high_throughput": Use deepep high-throughput kernels + - "deepep_low_latency": Use deepep low-latency kernels + - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl""" + num_redundant_experts: int | None = None """`num_redundant_experts` is deprecated and has been replaced with `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. Please use `eplb_config.num_redundant_experts` instead.""" - eplb_window_size: Optional[int] = None + eplb_window_size: int | None = None """`eplb_window_size` is deprecated and has been replaced with `eplb_config.window_size`. This will be removed in v0.12.0. Please use `eplb_config.window_size` instead.""" - eplb_step_interval: Optional[int] = None + eplb_step_interval: int | None = None """`eplb_step_interval` is deprecated and has been replaced with `eplb_config.step_interval`. This will be removed in v0.12.0. Please use `eplb_config.step_interval` instead.""" - eplb_log_balancedness: Optional[bool] = None + eplb_log_balancedness: bool | None = None """`eplb_log_balancedness` is deprecated and has been replaced with `eplb_config.log_balancedness`. This will be removed in v0.12.0. Please use `eplb_config.log_balancedness` instead.""" - max_parallel_loading_workers: Optional[int] = None + max_parallel_loading_workers: int | None = None """Maximum number of parallel loading workers when loading model sequentially in multiple batches. 
To avoid RAM OOM when using tensor parallel and large models.""" @@ -159,15 +182,15 @@ class ParallelConfig: ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" - ray_runtime_env: Optional[RuntimeEnv] = None + ray_runtime_env: RuntimeEnv | None = None """Ray runtime environment to pass to distributed workers.""" - placement_group: Optional[PlacementGroup] = None + placement_group: PlacementGroup | None = None """ray distributed model workers placement group.""" - distributed_executor_backend: Optional[ - Union[str, DistributedExecutorBackend, type[ExecutorBase]] - ] = None + distributed_executor_backend: ( + str | DistributedExecutorBackend | type[Executor] | None + ) = None """Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size is less than @@ -188,13 +211,13 @@ class is dynamically inherited by the worker class. This is used to inject new attributes and methods to the worker class for use in collective_rpc calls.""" - world_size: int = field(init=False) + world_size: int = Field(init=False) """world_size is TPxPP, it affects the number of workers we create.""" rank: int = 0 """Global rank in distributed setup.""" - _data_parallel_master_port_list: list[int] = field(default_factory=list) + _data_parallel_master_port_list: list[int] = Field(default_factory=list) """List of open port auto-queried for data parallel messaging. Set to be private as it's not intended to be configured by users. """ @@ -204,7 +227,7 @@ class is dynamically inherited by the worker class. This is used to inject not change by dcp, it simply reuse the GPUs of TP group, and tp_size needs to be divisible by dcp_size.""" - _api_process_count: int = 1 + _api_process_count: int = Field(default=1, gt=0) """ The number of API processes initialized. @@ -213,7 +236,7 @@ class is dynamically inherited by the worker class. This is used to inject should only be set by API server scale-out. """ - _api_process_rank: int = 0 + _api_process_rank: int = Field(default=0, ge=-1) """ The rank of this API process, or `-1` for engine core processes under API server scale-out. @@ -223,6 +246,51 @@ class is dynamically inherited by the worker class. This is used to inject should only be set by API server scale-out. """ + @model_validator(mode="after") + def _validate_parallel_config(self) -> Self: + if self._api_process_rank >= self._api_process_count: + raise ValueError( + "Invalid value of `_api_process_rank`. " + f"Expected to be `-1` or `[0, {self._api_process_count})`, " + f"but found: {self._api_process_rank}" + ) + + if self.data_parallel_size_local > self.data_parallel_size: + raise ValueError( + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})" + ) + + if self.data_parallel_size <= 1 and self.data_parallel_external_lb: + raise ValueError( + "data_parallel_external_lb can only be set when data_parallel_size > 1" + ) + + if self.enable_eplb: + if not current_platform.is_cuda(): + raise ValueError( + "Expert parallelism load balancing is only supported on " + "CUDA devices now." 
+ ) + if not self.enable_expert_parallel: + raise ValueError("enable_expert_parallel must be True to use EPLB.") + if self.tensor_parallel_size * self.data_parallel_size <= 1: + raise ValueError( + "EPLB requires tensor_parallel_size or data_parallel_size " + f"to be greater than 1, but got " + f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." + ) + else: + if self.eplb_config.num_redundant_experts != 0: + raise ValueError( + "num_redundant_experts is set to " + f"{self.eplb_config.num_redundant_experts} but EPLB is not " + "enabled. Either enable EPLB or unset " + "num_redundant_experts." + ) + + return self + @property def world_size_across_dp(self) -> int: """world_size_across_dp is TPxPPxDP, it is the size of the world @@ -261,7 +329,7 @@ def stateless_init_dp_group(self) -> ProcessGroup: ) max_retries = 5 - last_exc: Optional[Exception] = None + last_exc: Exception | None = None for _ in range(max_retries): try: # use gloo since the engine process might not have cuda device @@ -270,7 +338,7 @@ def stateless_init_dp_group(self) -> ProcessGroup: self.get_next_dp_init_port(), self.data_parallel_rank, self.data_parallel_size, - backend="gloo", + backend=current_platform.dist_backend, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. @@ -296,7 +364,7 @@ def stateless_init_dp_group(self) -> ProcessGroup: @property def use_sequence_parallel_moe(self) -> bool: return ( - envs.VLLM_ALL2ALL_BACKEND + self.all2all_backend in ( "allgather_reducescatter", "naive", @@ -345,7 +413,7 @@ def compute_hash(self): factors.append(self.tensor_parallel_size) factors.append(self.enable_expert_parallel) factors.append(self.data_parallel_size) - factors.append(envs.VLLM_ALL2ALL_BACKEND) + factors.append(self.all2all_backend) factors.append(self.enable_eplb) if self.enable_eplb: factors.append(self.eplb_config.log_balancedness) @@ -355,6 +423,16 @@ def compute_hash(self): return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: + # Set all2all_backend from env var if not specified, with deprecation warning + if self.all2all_backend is None: + self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if envs.is_set("VLLM_ALL2ALL_BACKEND"): + logger.warning_once( + "VLLM_ALL2ALL_BACKEND environment variable is deprecated and " + "will be removed in a future release. Please use the " + "--all2all-backend command-line argument instead." + ) + # Forward deprecated fields to their new location if self.num_redundant_experts is not None: self.eplb_config.num_redundant_experts = self.num_redundant_experts @@ -396,12 +474,6 @@ def __post_init__(self) -> None: logger.info("Using external launcher for distributed inference.") self.world_size *= self.data_parallel_size - if self.data_parallel_size_local > self.data_parallel_size: - raise ValueError( - f"data_parallel_size_local ({self.data_parallel_size_local}) " - f"must be <= data_parallel_size ({self.data_parallel_size})" - ) - if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. 
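A quick sketch of the precedence described above between the new `--all2all-backend` option and the deprecated `VLLM_ALL2ALL_BACKEND` environment variable (illustrative only; the function and the fallback default are assumptions for this sketch, not vLLM internals):

import os
import warnings

def resolve_all2all_backend(configured: str | None) -> str:
    # Explicit config wins; otherwise fall back to the deprecated env var.
    if configured is not None:
        return configured
    env_value = os.environ.get("VLLM_ALL2ALL_BACKEND")
    if env_value is not None:
        warnings.warn(
            "VLLM_ALL2ALL_BACKEND is deprecated; use --all2all-backend instead.",
            DeprecationWarning,
        )
        return env_value
    return "allgather_reducescatter"  # assumed default for this sketch

assert resolve_all2all_backend("deepep_low_latency") == "deepep_low_latency"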
if self.distributed_executor_backend == "external_launcher": @@ -431,48 +503,15 @@ def __post_init__(self) -> None: self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT - if self.data_parallel_external_lb: - raise ValueError( - "data_parallel_external_lb can only " - "be set when data_parallel_size > 1" - ) - if self.distributed_executor_backend == "external_launcher": os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" logger.info("Disabling V1 multiprocessing for external launcher.") - if self.enable_eplb: - if not current_platform.is_cuda(): - raise ValueError( - "Expert parallelism load balancing is only supported on " - "CUDA devices now." - ) - if self.eplb_config.num_redundant_experts < 0: - raise ValueError( - "num_redundant_experts must be non-negative, but got " - f"{self.eplb_config.num_redundant_experts}." - ) - if not self.enable_expert_parallel: - raise ValueError("enable_expert_parallel must be True to use EPLB.") - if self.tensor_parallel_size * self.data_parallel_size <= 1: - raise ValueError( - "EPLB requires tensor_parallel_size or data_parallel_size " - f"to be greater than 1, but got " - f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." - ) - else: - if self.eplb_config.num_redundant_experts != 0: - raise ValueError( - "num_redundant_experts is set to " - f"{self.eplb_config.num_redundant_experts} but EPLB is not " - "enabled. Either enable EPLB or unset " - "num_redundant_experts." - ) if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. - from vllm.executor import ray_utils + from vllm.v1.executor import ray_utils backend: DistributedExecutorBackend = "mp" ray_found = ray_utils.ray_is_available() @@ -514,11 +553,10 @@ def __post_init__(self) -> None: if self.distributed_executor_backend is None and self.world_size == 1: self.distributed_executor_backend = "uni" - if not -1 <= self._api_process_rank < self._api_process_count: - raise ValueError( - "Invalid value of `_api_process_rank`. " - f"Expected to be `-1` or `[0, {self._api_process_count})`, " - f"but found: {self._api_process_rank}" + if self.max_parallel_loading_workers is not None: + logger.warning( + "max_parallel_loading_workers is currently " + "not supported and will be ignored." ) @property @@ -531,25 +569,28 @@ def use_ray(self) -> bool: @model_validator(mode="after") def _verify_args(self) -> Self: # Lazy import to avoid circular import - from vllm.executor.executor_base import ExecutorBase - from vllm.platforms import current_platform + from vllm.v1.executor import Executor + + # Enable batch invariance settings if requested + if vllm_is_batch_invariant(): + self.disable_custom_all_reduce = True if ( self.distributed_executor_backend is not None and not isinstance(self.distributed_executor_backend, str) and not ( isinstance(self.distributed_executor_backend, type) - and issubclass(self.distributed_executor_backend, ExecutorBase) + and issubclass(self.distributed_executor_backend, Executor) ) ): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. Supported " "values are 'ray', 'mp' 'uni', 'external_launcher', " - " custom ExecutorBase subclass or its import path." + " custom Executor subclass or its import path." 
) if self.use_ray: - from vllm.executor import ray_utils + from vllm.v1.executor import ray_utils ray_utils.assert_ray_available() diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index 8b10992faa02..0590f74aa4c9 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib -from typing import Any, Optional +from typing import Any from pydantic.dataclasses import dataclass @@ -14,23 +14,23 @@ class PoolerConfig: """Controls the behavior of output pooling in pooling models.""" - pooling_type: Optional[str] = None + pooling_type: str | None = None """ The pooling method of the pooling model. This should be a key in [`vllm.model_executor.layers.pooler.PoolingType`][]. """ ## for embeddings models - normalize: Optional[bool] = None + normalize: bool | None = None """ Whether to normalize the embeddings outputs. Defaults to True. """ - dimensions: Optional[int] = None + dimensions: int | None = None """ Reduce the dimensions of embeddings if model support matryoshka representation. Defaults to None. """ - enable_chunked_processing: Optional[bool] = None + enable_chunked_processing: bool | None = None """ Whether to enable chunked processing for long inputs that exceed the model's maximum position embeddings. When enabled, long inputs will be split into @@ -38,7 +38,7 @@ class PoolerConfig: This allows embedding models to handle arbitrarily long text without CUDA errors. Defaults to False. """ - max_embed_len: Optional[int] = None + max_embed_len: int | None = None """ Maximum input length allowed for embedding generation. When set, allows inputs longer than max_embed_len to be accepted for embedding models. @@ -48,33 +48,33 @@ class PoolerConfig: """ ## for classification models - activation: Optional[bool] = None + activation: bool | None = None """ Whether to apply activation function to the classification outputs. Defaults to True. """ - logit_bias: Optional[float] = None + logit_bias: float | None = None """ If provided, apply classification logit biases. Defaults to None. """ ## for reward models - softmax: Optional[bool] = None + softmax: bool | None = None """ Whether to apply softmax to the reward outputs. Defaults to True. """ - step_tag_id: Optional[int] = None + step_tag_id: int | None = None """ - If set, only the score corresponding to the ``step_tag_id`` in the + If set, only the score corresponding to the `step_tag_id` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ - returned_token_ids: Optional[list[int]] = None + returned_token_ids: list[int] | None = None """ A list of indices for the vocabulary dimensions to be extracted, - such as the token IDs of ``good_token`` and ``bad_token`` in the - ``math-shepherd-mistral-7b-prm`` model. + such as the token IDs of `good_token` and `bad_token` in the + `math-shepherd-mistral-7b-prm` model. 
""" def compute_hash(self) -> str: diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 396258aac287..af47531501cf 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -3,7 +3,7 @@ import hashlib from dataclasses import InitVar, field -from typing import Any, Literal, Union +from typing import Any, Literal from pydantic import SkipValidation, model_validator from pydantic.dataclasses import dataclass @@ -71,14 +71,6 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - cuda_graph_sizes: list[int] = field(default_factory=list) - """Cuda graph capture sizes - 1. if none provided, then default set to [min(max_num_seqs * 2, 512)] - 2. if one value is provided, then the capture list would follow the - pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] - 3. more than one value (e.g. 1 2 128) is provided, then the capture list - will follow the provided list.""" - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -107,12 +99,6 @@ class SchedulerConfig: NOTE: This is not currently configurable. It will be overridden by max_num_batched_tokens in case max multimodal embedding size is larger.""" - send_delta_data: bool = False - """Private API. If used, scheduler sends delta data to - workers instead of an entire data. It should be enabled only - when SPMD worker architecture is enabled. I.e., - VLLM_USE_RAY_SPMD_WORKER=1""" - policy: SchedulerPolicy = "fcfs" """The scheduling policy to use:\n - "fcfs" means first come first served, i.e. requests are handled in order @@ -131,12 +117,12 @@ class SchedulerConfig: some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled as TTTT in one step and IIIIIIIIII in the next.""" - # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) - # or "mod.custom_class". - scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" - """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the - default scheduler. Can be a class directly or the path to a class of form - "mod.custom_class".""" + # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler" + # (default) or "mod.custom_class". + scheduler_cls: str | type[object] = "vllm.v1.core.sched.scheduler.Scheduler" + """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is + the default scheduler. Can be a class directly or the path to a class of + form "mod.custom_class".""" disable_hybrid_kv_cache_manager: bool = False """If set to True, KV cache manager will allocate the same size of KV cache @@ -241,13 +227,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: self.long_prefill_token_threshold, ) - # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)]. - # This avoids OOM in tight memory scenarios with small max_num_seqs, - # and prevents capture of many large graphs (>512) that would greatly - # increase startup time with limited performance benefit. 
- if not self.cuda_graph_sizes: - self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)] - if self.async_scheduling: self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler" diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index aa0c07cf62a3..4c7b7369ed4b 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -3,9 +3,9 @@ import ast import hashlib -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal -from pydantic import SkipValidation, model_validator +from pydantic import Field, SkipValidation, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -13,7 +13,7 @@ from vllm.config.parallel import ParallelConfig from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import LazyLoader +from vllm.utils.import_utils import LazyLoader if TYPE_CHECKING: from transformers import PretrainedConfig @@ -59,16 +59,16 @@ class SpeculativeConfig: """Configuration for speculative decoding.""" - enforce_eager: Optional[bool] = None + enforce_eager: bool | None = None """Override the default enforce_eager from model_config""" # General speculative decoding control - num_speculative_tokens: SkipValidation[int] = None # type: ignore + num_speculative_tokens: int = Field(default=None, gt=0) """The number of speculative tokens, if provided. It will default to the number in the draft model config if present, otherwise, it is required.""" - model: Optional[str] = None + model: str | None = None """The name of the draft model, eagle head, or additional weights, if provided.""" - method: Optional[SpeculativeMethod] = None + method: SpeculativeMethod | None = None """The name of the speculative method to use. If users provide and set the `model` param, the speculative method type will be detected automatically if possible, if `model` param is not provided, the method name must be @@ -76,7 +76,7 @@ class SpeculativeConfig: If using `ngram` method, the related configuration `prompt_lookup_max` and `prompt_lookup_min` should be considered.""" - draft_tensor_parallel_size: Optional[int] = None + draft_tensor_parallel_size: int | None = Field(default=None, ge=1) """The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size.""" disable_logprobs: bool = True @@ -85,24 +85,24 @@ class SpeculativeConfig: according to the log probability settings in SamplingParams.""" # Draft model configuration - quantization: Optional[me_quant.QuantizationMethods] = None + quantization: me_quant.QuantizationMethods | None = None """Quantization method that was used to quantize the draft model weights. If `None`, we assume the model weights are not quantized. Note that it only takes effect when using the draft model-based speculative method.""" - max_model_len: Optional[int] = None + max_model_len: int | None = Field(default=None, ge=1) """The maximum model length of the draft model. Used when testing the ability to skip speculation for some sequences.""" - revision: Optional[str] = None + revision: str | None = None """The specific model version to use for the draft model. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" - code_revision: Optional[str] = None + code_revision: str | None = None """The specific revision to use for the draft model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. 
If unspecified, will use the default version.""" # Advanced control - disable_by_batch_size: Optional[int] = None + disable_by_batch_size: int | None = Field(default=None, ge=2) """Disable speculative decoding for new incoming requests when the number of enqueued requests is larger than this value, if provided.""" disable_padded_drafter_batch: bool = False @@ -112,14 +112,14 @@ class SpeculativeConfig: only affects the EAGLE method of speculation.""" # Ngram proposer configuration - prompt_lookup_max: Optional[int] = None + prompt_lookup_max: int | None = Field(default=None, ge=1) """Maximum size of ngram token window when using Ngram proposer, required when method is set to ngram.""" - prompt_lookup_min: Optional[int] = None + prompt_lookup_min: int | None = Field(default=None, ge=1) """Minimum size of ngram token window when using Ngram proposer, if provided. Defaults to 1.""" - speculative_token_tree: Optional[str] = None + speculative_token_tree: str | None = None """Specifies the tree structure for speculative token generation. """ # required configuration params passed from engine @@ -232,9 +232,8 @@ def __post_init__(self): if self.model is None and self.num_speculative_tokens is not None: if self.method == "mtp": - assert self.target_model_config is not None, ( - "target_model_config must be present for mtp" - ) + if self.target_model_config is None: + raise ValueError("target_model_config must be present for mtp") if self.target_model_config.hf_text_config.model_type == "deepseek_v32": # FIXME(luccafong): cudgraph with v32 MTP is not supported, # remove this when the issue is fixed. @@ -268,21 +267,21 @@ def __post_init__(self): self.prompt_lookup_min = 5 self.prompt_lookup_max = 5 elif self.prompt_lookup_min is None: - assert self.prompt_lookup_max is not None + if self.prompt_lookup_max is None: + raise ValueError( + "Either prompt_lookup_max or prompt_lookup_min must be " + "provided when using the ngram method." + ) self.prompt_lookup_min = self.prompt_lookup_max elif self.prompt_lookup_max is None: - assert self.prompt_lookup_min is not None + if self.prompt_lookup_min is None: + raise ValueError( + "Either prompt_lookup_max or prompt_lookup_min must be " + "provided when using the ngram method." 
+ ) self.prompt_lookup_max = self.prompt_lookup_min # Validate values - if self.prompt_lookup_min < 1: - raise ValueError( - f"prompt_lookup_min={self.prompt_lookup_min} must be > 0" - ) - if self.prompt_lookup_max < 1: - raise ValueError( - f"prompt_lookup_max={self.prompt_lookup_max} must be > 0" - ) if self.prompt_lookup_min > self.prompt_lookup_max: raise ValueError( f"prompt_lookup_min={self.prompt_lookup_min} must " @@ -446,10 +445,11 @@ def __post_init__(self): self.target_parallel_config, self.draft_tensor_parallel_size ) ) + return self @staticmethod def _maybe_override_draft_max_model_len( - speculative_max_model_len: Optional[int], + speculative_max_model_len: int | None, draft_max_model_len: int, target_max_model_len: int, ) -> int: @@ -488,7 +488,7 @@ def _maybe_override_draft_max_model_len( @staticmethod def _verify_and_get_draft_tp( target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], + speculative_draft_tensor_parallel_size: int | None, draft_hf_config: PretrainedConfig, ) -> int: """ diff --git a/vllm/config/speech_to_text.py b/vllm/config/speech_to_text.py index de9f525efe18..3eafff1a3060 100644 --- a/vllm/config/speech_to_text.py +++ b/vllm/config/speech_to_text.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from pydantic.dataclasses import dataclass @@ -28,7 +27,7 @@ class SpeechToTextConfig: splitting long audio. This helps maintain context across chunk boundaries and improves transcription quality at split points.""" - min_energy_split_window_size: Optional[int] = 1600 + min_energy_split_window_size: int | None = 1600 """Window size in samples for finding low-energy (quiet) regions to split audio chunks. The algorithm looks for the quietest moment within this window to minimize cutting through speech. Default 1600 samples ≈ 100ms diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index 5111c9c77d90..76b565006e28 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -35,6 +35,8 @@ class StructuredOutputsConfig: reasoning_parser: str = "" """Select the reasoning parser depending on the model that you're using. 
This is used to parse the reasoning content into OpenAI API format.""" + enable_in_reasoning: bool = False + """Whether to use structured input for reasoning.""" def compute_hash(self) -> str: """ diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 889ebf45b12d..5e7e7580c5a9 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -7,9 +7,11 @@ import textwrap from collections.abc import Iterable from dataclasses import MISSING, Field, field, fields, is_dataclass, replace +from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar import regex as re +from pydantic.fields import FieldInfo from typing_extensions import runtime_checkable if TYPE_CHECKING: @@ -49,7 +51,14 @@ def get_field(cls: ConfigType, name: str) -> Field: if (default_factory := named_field.default_factory) is not MISSING: return field(default_factory=default_factory) if (default := named_field.default) is not MISSING: + if isinstance(default, FieldInfo): + # Handle pydantic.Field defaults + if default.default_factory is not None: + return field(default_factory=default.default_factory) + else: + default = default.default return field(default=default) + raise ValueError( f"{cls.__name__}.{name} must have a default value or default factory." ) @@ -102,30 +111,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: https://davidism.com/mit-license/ """ - def pairwise(iterable): - """ - Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise - - Can be removed when Python 3.9 support is dropped. - """ - iterator = iter(iterable) - a = next(iterator, None) - - for b in iterator: - yield a, b - a = b - - try: - cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] - except (OSError, KeyError, TypeError): - # HACK: Python 3.13+ workaround - set missing __firstlineno__ - # Workaround can be removed after we upgrade to pydantic==2.12.0 - with open(inspect.getfile(cls)) as f: - for i, line in enumerate(f): - if f"class {cls.__name__}" in line and ":" in line: - cls.__firstlineno__ = i + 1 - break - cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] if not isinstance(cls_node, ast.ClassDef): raise TypeError("Given object was not a class.") diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 37b8c3fe6677..916f258d6586 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -5,14 +5,15 @@ import hashlib import json import os +import time from contextlib import contextmanager -from dataclasses import field, replace +from dataclasses import replace from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, TypeVar import torch -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from pydantic.dataclasses import dataclass import vllm.envs as envs @@ -21,7 +22,7 @@ from vllm.utils import random_uuid from .cache import CacheConfig -from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode +from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode from .device import DeviceConfig from .kv_events import KVEventsConfig from .kv_transfer import KVTransferConfig @@ -39,11 +40,14 @@ from transformers import PretrainedConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + from vllm.v1.kv_cache_interface import KVCacheConfig else: PretrainedConfig = Any 
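The `get_field` change in vllm/config/utils.py above is easiest to see in isolation: when a dataclass attribute was declared with `pydantic.Field(...)`, its stored default is a `FieldInfo` object that has to be unwrapped. A minimal sketch of that unwrapping (not the actual vLLM helper):

from dataclasses import field
from pydantic import Field
from pydantic.fields import FieldInfo

def unwrap_default(default):
    # If the stored default is a pydantic FieldInfo, recover either its
    # default_factory or its plain default; otherwise pass it through.
    if isinstance(default, FieldInfo):
        if default.default_factory is not None:
            return field(default_factory=default.default_factory)
        return field(default=default.default)
    return field(default=default)

print(unwrap_default(Field(default=4)).default)                       # expected: 4
print(unwrap_default(Field(default_factory=list)).default_factory())  # expected: []
print(unwrap_default(0.5).default)                                    # expected: 0.5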
QuantizationConfig = Any + KVCacheConfig = Any + logger = init_logger(__name__) @@ -56,53 +60,47 @@ class VllmConfig: # TODO: use default_factory once default constructing ModelConfig doesn't # try to download a model - model_config: ModelConfig = None # type: ignore + model_config: ModelConfig = Field(default=None) """Model configuration.""" - cache_config: CacheConfig = field(default_factory=CacheConfig) + cache_config: CacheConfig = Field(default_factory=CacheConfig) """Cache configuration.""" - parallel_config: ParallelConfig = field(default_factory=ParallelConfig) + parallel_config: ParallelConfig = Field(default_factory=ParallelConfig) """Parallel configuration.""" - scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) + scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig) """Scheduler configuration.""" - device_config: DeviceConfig = field(default_factory=DeviceConfig) + device_config: DeviceConfig = Field(default_factory=DeviceConfig) """Device configuration.""" - load_config: LoadConfig = field(default_factory=LoadConfig) + load_config: LoadConfig = Field(default_factory=LoadConfig) """Load configuration.""" - lora_config: Optional[LoRAConfig] = None + lora_config: LoRAConfig | None = None """LoRA configuration.""" - speculative_config: Optional[SpeculativeConfig] = None + speculative_config: SpeculativeConfig | None = None """Speculative decoding configuration.""" - structured_outputs_config: StructuredOutputsConfig = field( + structured_outputs_config: StructuredOutputsConfig = Field( default_factory=StructuredOutputsConfig ) """Structured outputs configuration.""" - observability_config: Optional[ObservabilityConfig] = None + observability_config: ObservabilityConfig | None = None """Observability configuration.""" - quant_config: Optional[QuantizationConfig] = None + quant_config: QuantizationConfig | None = None """Quantization configuration.""" - compilation_config: CompilationConfig = field(default_factory=CompilationConfig) + compilation_config: CompilationConfig = Field(default_factory=CompilationConfig) """`torch.compile` and cudagraph capture configuration for the model. - As a shorthand, `-O<n>` can be used to directly specify the compilation - level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). - Currently, -O <n> and -O=<n> are supported as well but this will likely be - removed in favor of clearer -O<n> syntax in the future. - - NOTE: level 0 is the default level without any optimization. level 1 and 2 - are for internal testing only. level 3 is the recommended level for - production, also default in V1. + As a shorthand, one can append compilation arguments via + `-O.<param>=<value>`, e.g. `-O.mode=3` (same as `-O='{"mode":3}'`). You can specify the full compilation config like so: - `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` + `{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` """ - kv_transfer_config: Optional[KVTransferConfig] = None + kv_transfer_config: KVTransferConfig | None = None """The configurations for distributed KV cache transfer.""" - kv_events_config: Optional[KVEventsConfig] = None + kv_events_config: KVEventsConfig | None = None """The configurations for event publishing.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration.
- additional_config: Union[dict, SupportsHash] = field(default_factory=dict) + additional_config: dict | SupportsHash = Field(default_factory=dict) """Additional config for specified platform. Different platforms may support different configs. Make sure the configs are valid for the platform you are using. Contents must be hashable.""" @@ -202,16 +200,16 @@ def compute_hash(self) -> str: return hash_str def pad_for_cudagraph(self, batch_size: int) -> int: - # if batch_size > self.compilation_config.max_capture_size, + # if batch_size > self.compilation_config.max_cudagraph_capture_size, # it should raise an IndexError. # the caller should make sure the batch_size is within the range, - # i.e., batch_size <= self.compilation_config.max_capture_size + # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size return self.compilation_config.bs_to_padded_graph_size[batch_size] @staticmethod def _get_quantization_config( model_config: ModelConfig, load_config: LoadConfig - ) -> Optional[QuantizationConfig]: + ) -> QuantizationConfig | None: """Get the quantization config.""" from vllm.platforms import current_platform @@ -244,7 +242,7 @@ def _get_quantization_config( @staticmethod def get_quantization_config( model_config: ModelConfig, load_config: LoadConfig - ) -> Optional[QuantizationConfig]: + ) -> QuantizationConfig | None: import copy # For some reason, the _ version of this modifies the model_config @@ -256,7 +254,7 @@ def get_quantization_config( def with_hf_config( self, hf_config: PretrainedConfig, - architectures: Optional[list[str]] = None, + architectures: list[str] | None = None, ) -> "VllmConfig": if architectures is not None: hf_config = copy.deepcopy(hf_config) @@ -270,6 +268,9 @@ def with_hf_config( def __post_init__(self): """Verify configs are valid & consistent with each other.""" + # To give each torch profile run a unique instance name. + self.instance_id = f"{time.time_ns()}" + self.try_verify_and_update_config() if self.model_config is not None: @@ -301,38 +302,33 @@ def __post_init__(self): "precision for chunked prefill triton kernels." ) - # If the user does not explicitly set a compilation level, then - # we use the default level. The default level depends on other + # If the user does not explicitly set a compilation mode, then + # we use the default mode. The default mode depends on other # settings (see the below code). - if self.compilation_config.level is None: + if self.compilation_config.mode is None: if envs.VLLM_USE_V1: if ( self.model_config is not None and not self.model_config.enforce_eager ): - self.compilation_config.level = CompilationLevel.PIECEWISE + self.compilation_config.mode = CompilationMode.VLLM_COMPILE else: - self.compilation_config.level = CompilationLevel.NO_COMPILATION + self.compilation_config.mode = CompilationMode.NONE else: - # NB: Passing both --enforce-eager and a compilation level - # in V0 means the compilation level wins out. - self.compilation_config.level = CompilationLevel.NO_COMPILATION + # NB: Passing both --enforce-eager and a compilation mode + # in V0 means the compilation mode wins out. 
+ self.compilation_config.mode = CompilationMode.NONE else: - assert self.compilation_config.level >= CompilationLevel.NO_COMPILATION - assert self.compilation_config.level <= CompilationLevel.PIECEWISE - assert self.compilation_config.level <= 3 + assert self.compilation_config.mode >= CompilationMode.NONE + assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE # If user does not set custom ops via none or all set it here based on - # compilation level and backend. - if ( - self.compilation_config.custom_ops.count("none") - + self.compilation_config.custom_ops.count("all") - == 0 - ): + # compilation mode and backend. + if all(s not in self.compilation_config.custom_ops for s in ("all", "none")): if ( - self.compilation_config.level > 0 - and self.compilation_config.backend != "eager" + self.compilation_config.backend == "inductor" + and self.compilation_config.mode > CompilationMode.NONE ): self.compilation_config.custom_ops.append("none") else: @@ -351,27 +347,61 @@ def __post_init__(self): if self.compilation_config.cudagraph_mode is None: if ( envs.VLLM_USE_V1 - and self.compilation_config.level == CompilationLevel.PIECEWISE + and self.compilation_config.mode == CompilationMode.VLLM_COMPILE ): # default to full and piecewise for most models self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE ) + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - # pooling models and encoder-decoder models - # do not support full cudagraphs - if self.model_config is not None and ( - self.model_config.pooler_config is not None - or self.model_config.is_encoder_decoder + # if cudagraph_mode has full cudagraphs, we need to check support + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): + # decode context parallel does not support full cudagraphs + if self.parallel_config.decode_context_parallel_size > 1: + logger.warning_once( + "Decode context parallel (DCP) is enabled, which is " + "incompatible with full CUDA graphs. " + "Overriding cudagraph_mode to PIECEWISE." + ) + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + elif self.model_config is not None: + if self.model_config.pooler_config is not None: + logger.warning_once( + "Pooling models do not support full cudagraphs. " + "Overriding cudagraph_mode to PIECEWISE." + ) + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + elif self.model_config.is_encoder_decoder: + logger.warning_once( + "Encoder-decoder models do not support full cudagraphs. " + "Overriding cudagraph_mode to PIECEWISE." + ) + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + elif ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and self.model_config.max_model_len > 131072 + and not self.model_config.use_mla ): + # Refer to vllm/utils/flashinfer.py::use_trtllm_attention() + logger.warning_once( + "NVIDIA Blackwell TRTLLM attention cannot support " + "max_model_len >= 131072 (found " + f"{self.model_config.max_model_len}), causing dynamic " + "dispatching that breaks full cudagraphs. " + "Overriding cudagraph_mode to PIECEWISE." 
+ ) self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - else: - self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE # disable cudagraph when enforce eager execution if self.model_config is not None and self.model_config.enforce_eager: logger.info("Cudagraph is disabled under eager mode") self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + # override related settings when enforce eager + self.compilation_config.max_cudagraph_capture_size = 0 + self.compilation_config.cudagraph_capture_sizes = [] elif envs.VLLM_USE_V1: self.compilation_config.cudagraph_num_of_warmups = 1 @@ -478,10 +508,10 @@ def __post_init__(self): ) current_platform.check_and_update_config(self) - # Do this after all the updates to compilation_config.level + # Do this after all the updates to compilation_config.mode if ( envs.VLLM_USE_V1 - and self.compilation_config.level == CompilationLevel.PIECEWISE + and self.compilation_config.mode == CompilationMode.VLLM_COMPILE ): self.compilation_config.set_splitting_ops_for_v1() @@ -500,8 +530,8 @@ def __post_init__(self): ) if self.compilation_config.cudagraph_mode.requires_piecewise_compilation(): - assert self.compilation_config.level == CompilationLevel.PIECEWISE, ( - "Compilation level should be CompilationLevel.PIECEWISE " + assert self.compilation_config.mode == CompilationMode.VLLM_COMPILE, ( + "Compilation mode should be CompilationMode.VLLM_COMPILE " "when cudagraph_mode piecewise cudagraphs is used, " f"cudagraph_mode={self.compilation_config.cudagraph_mode}" ) @@ -515,13 +545,13 @@ def __post_init__(self): ) if self.parallel_config.enable_dbo: - a2a_backend = envs.VLLM_ALL2ALL_BACKEND + a2a_backend = self.parallel_config.all2all_backend assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( "Microbatching currently only supports the deepep_low_latency and " f"deepep_high_throughput all2all backend. {a2a_backend} is not " - "supported. To fix set the VLLM_ALL2ALL_BACKEND environment " - "variable to deepep_low_latency or deepep_high_throughput and " - "install the DeepEP kernels." + "supported. To fix use --all2all-backend=deepep_low_latency or " + "--all2all-backend=deepep_high_throughput and install the DeepEP" + " kernels." ) if not self.model_config.disable_cascade_attn: @@ -541,9 +571,6 @@ def __post_init__(self): if not current_platform.support_hybrid_kv_cache(): # Hybrid KV cache manager is not supported on non-GPU platforms. self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_transfer_config is not None: - # Hybrid KV cache manager is not compatible with KV transfer. - self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. 
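The cudagraph-mode fall-backs introduced above follow a simple precedence, summarised in this condensed, illustrative decision function (a sketch under the assumptions visible in this hunk; it deliberately omits the Blackwell long-context special case and is not vLLM's actual code):

from enum import Enum

class CUDAGraphMode(Enum):   # simplified stand-in for the real enum
    NONE = 0
    PIECEWISE = 1
    FULL_AND_PIECEWISE = 2

def default_cudagraph_mode(
    vllm_compile: bool,
    enforce_eager: bool,
    dcp_size: int = 1,
    is_pooling_model: bool = False,
    is_encoder_decoder: bool = False,
) -> CUDAGraphMode:
    if enforce_eager or not vllm_compile:
        return CUDAGraphMode.NONE
    if dcp_size > 1 or is_pooling_model or is_encoder_decoder:
        return CUDAGraphMode.PIECEWISE   # full cudagraphs unsupported here
    return CUDAGraphMode.FULL_AND_PIECEWISE

assert default_cudagraph_mode(True, False) is CUDAGraphMode.FULL_AND_PIECEWISE
assert default_cudagraph_mode(True, False, dcp_size=2) is CUDAGraphMode.PIECEWISE
assert default_cudagraph_mode(True, True) is CUDAGraphMode.NONE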
self.scheduler_config.disable_hybrid_kv_cache_manager = True @@ -597,7 +624,7 @@ def has_blocked_weights(): # https://github.com/vllm-project/vllm/issues/25094 if has_blocked_weights(): custom_ops = self.compilation_config.custom_ops - if "none" not in custom_ops and "-quant_fp8" not in custom_ops: + if "-quant_fp8" not in custom_ops: custom_ops.append("+quant_fp8") def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: @@ -630,11 +657,13 @@ def _set_cudagraph_sizes(self): ```python max_graph_size = min(max_num_seqs * 2, 512) - # 1, 2, 4, then multiples of 8 up to max_graph_size - cuda_graph_sizes = [1, 2, 4, 8, 16, 24, 32, 40, ..., max_graph_size] + # 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16 + # up to max_graph_size + cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( + range(256, max_graph_size + 1, 16)) In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` - will be the final sizes to capture cudagraph (in descending order). + will be the final sizes to capture cudagraph (in ascending order). These sizes are used to capture and reuse CUDA graphs for performance-critical paths (e.g., decoding). Capturing enables @@ -661,35 +690,111 @@ def _set_cudagraph_sizes(self): not be used. """ - # calculate the default `batch_size_capture_list` - batch_size_capture_list = [] - if self.model_config is not None and not self.model_config.enforce_eager: - cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes - if len(cuda_graph_sizes) == 1: - max_graph_size = cuda_graph_sizes[0] - assert max_graph_size >= 1, ( - "Maximum cudagraph size should be greater than or equal to 1." + if ( + self.model_config is not None + and not self.model_config.enforce_eager + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): + # determine the initial max_cudagraph_capture_size + max_cudagraph_capture_size = ( + self.compilation_config.max_cudagraph_capture_size + ) + if max_cudagraph_capture_size is None: + max_cudagraph_capture_size = min( + self.scheduler_config.max_num_seqs * 2, 512 + ) + max_num_tokens = self.scheduler_config.max_num_batched_tokens + max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size) + + assert max_cudagraph_capture_size >= 1, ( + "Maximum cudagraph size should be greater than or equal to 1 " + "when using cuda graph." + ) + + # determine the cudagraph_capture_sizes + if self.compilation_config.cudagraph_capture_sizes is not None: + assert len(self.compilation_config.cudagraph_capture_sizes) > 0, ( + "cudagraph_capture_sizes should contain at least one element " + "when using cuda graph." 
) - batch_size_capture_list = [ - i for i in [1, 2, 4] if i <= max_graph_size - ] + list(range(8, max_graph_size + 1, 8)) - elif len(cuda_graph_sizes) > 1: - batch_size_capture_list = sorted(cuda_graph_sizes) + # de-duplicate the sizes provided by the config + dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes)) + cudagraph_capture_sizes = dedup_sizes + # sort to make sure the sizes are in ascending order + cudagraph_capture_sizes.sort() else: - raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") + cudagraph_capture_sizes = [ + i for i in [1, 2, 4] if i <= max_cudagraph_capture_size + ] + if max_cudagraph_capture_size >= 8: + # Step size 8 for small batch sizes, up to 256(not included) + cudagraph_capture_sizes += list( + range(8, min(max_cudagraph_capture_size + 1, 256), 8) + ) + if max_cudagraph_capture_size >= 256: + # Step size 16 for larger batch sizes + cudagraph_capture_sizes += list( + range(256, max_cudagraph_capture_size + 1, 16) + ) + if ( self.parallel_config.tensor_parallel_size > 1 and self.compilation_config.pass_config.enable_sequence_parallelism ): - batch_size_capture_list = self.update_sizes_for_sequence_parallelism( - batch_size_capture_list + cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism( + cudagraph_capture_sizes ) - max_num_tokens = self.scheduler_config.max_num_batched_tokens - batch_size_capture_list = [ - size for size in batch_size_capture_list if size <= max_num_tokens - ] - self.compilation_config.init_with_cudagraph_sizes(batch_size_capture_list) + # user-specific compilation_config.max_cudagraph_capture_size get + # truncated to valid_max_size when they are inconsistent. + valid_max_size = ( + cudagraph_capture_sizes[-1] if cudagraph_capture_sizes else 0 + ) + if ( + self.compilation_config.max_cudagraph_capture_size is not None + and self.compilation_config.max_cudagraph_capture_size != valid_max_size + ): + # raise error only when both two flags are user-specified + # and they are inconsistent with each other + if self.compilation_config.cudagraph_capture_sizes is not None: + raise ValueError( + "customized max_cudagraph_capture_size" + f"(={self.compilation_config.max_cudagraph_capture_size}) " + "should be consistent with the max value of " + f"cudagraph_capture_sizes(={valid_max_size})" + ) + + logger.warning( + "Truncating max_cudagraph_capture_size to %d", + valid_max_size, + ) + # always set the final max_cudagraph_capture_size + self.compilation_config.max_cudagraph_capture_size = valid_max_size + + if self.compilation_config.cudagraph_capture_sizes is not None and len( + cudagraph_capture_sizes + ) < len(self.compilation_config.cudagraph_capture_sizes): + # If users have specified capture sizes, we only need to + # compare the lens before and after modification since the modified + # list is only the subset of the original list. + logger.warning( + ( + "cudagraph_capture_sizes specified in compilation_config" + " %s is overridden by config %s" + ), + self.compilation_config.cudagraph_capture_sizes, + cudagraph_capture_sizes, + ) + # always write back the final sizes + self.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes + + else: + # no cudagraph in use + self.compilation_config.max_cudagraph_capture_size = 0 + self.compilation_config.cudagraph_capture_sizes = [] + + # complete the remaining process. 
+ self.compilation_config.post_init_cudagraph_sizes() def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config @@ -738,15 +843,18 @@ def try_verify_and_update_config(self): "Overriding `load_format` to 'runai_streamer'" ) self.load_config.load_format = "runai_streamer" - elif self.load_config.load_format != "runai_streamer": + elif self.load_config.load_format not in ( + "runai_streamer", + "runai_streamer_sharded", + ): raise ValueError( f"To load a model from S3, 'load_format' " - f"must be 'runai_streamer', " + f"must be 'runai_streamer' or 'runai_streamer_sharded', " f"but got '{self.load_config.load_format}'. " f"Model: {self.model_config.model}" ) - def compile_debug_dump_path(self) -> Optional[Path]: + def compile_debug_dump_path(self) -> Path | None: """Returns a rank-aware path for dumping torch.compile debug information. """ @@ -796,13 +904,13 @@ def __str__(self): ) -_current_vllm_config: Optional[VllmConfig] = None -_current_prefix: Optional[str] = None +_current_vllm_config: VllmConfig | None = None +_current_prefix: str | None = None @contextmanager def set_current_vllm_config( - vllm_config: VllmConfig, check_compile=False, prefix: Optional[str] = None + vllm_config: VllmConfig, check_compile=False, prefix: str | None = None ): """ Temporarily set the current vLLM config. @@ -829,7 +937,7 @@ def set_current_vllm_config( if ( check_compile - and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE and compilation_counter.num_models_seen == num_models_seen ): # If the model supports compilation, @@ -872,7 +980,7 @@ def get_current_vllm_config() -> VllmConfig: def get_layers_from_vllm_config( vllm_config: VllmConfig, layer_type: type[T], - layer_names: Optional[list[str]] = None, + layer_names: list[str] | None = None, ) -> dict[str, T]: """ Get layers from the vLLM config. 
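The `_set_cudagraph_sizes` rework above replaces the old flat step-of-8 list with a two-tier schedule: step 8 below 256, step 16 from 256 up to the cap, clamped by the token budget. A minimal, self-contained sketch of that default path, assuming cudagraphs are enabled and no `cudagraph_capture_sizes` are user-provided; the helper name is illustrative, not a vLLM API:

```python
def default_capture_sizes(max_num_seqs: int, max_num_batched_tokens: int) -> list[int]:
    """Sketch of the default cudagraph capture-size schedule described above."""
    # initial cap: min(max_num_seqs * 2, 512), further clamped by the token budget
    max_size = min(max_num_seqs * 2, 512, max_num_batched_tokens)
    sizes = [i for i in [1, 2, 4] if i <= max_size]
    if max_size >= 8:
        # step size 8 for small batches, up to but not including 256
        sizes += list(range(8, min(max_size + 1, 256), 8))
    if max_size >= 256:
        # step size 16 for larger batches, up to and including max_size
        sizes += list(range(256, max_size + 1, 16))
    return sizes


# e.g. max_num_seqs=256 with a large token budget ->
# [1, 2, 4, 8, 16, ..., 248, 256, 272, ..., 512]
print(default_capture_sizes(256, 8192)[-3:])  # [480, 496, 512]
```

Per the surrounding hunks, `enforce_eager` (or `cudagraph_mode=NONE`) now zeroes `max_cudagraph_capture_size` and empties the capture list, and a user-specified `max_cudagraph_capture_size` is reconciled against user-specified capture sizes: an error is raised only when both are given and disagree, otherwise the max is truncated with a warning.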
diff --git a/vllm/connections.py b/vllm/connections.py index 8d5e0e5cbf5d..31b0d5e9c702 100644 --- a/vllm/connections.py +++ b/vllm/connections.py @@ -3,7 +3,6 @@ from collections.abc import Mapping, MutableMapping from pathlib import Path -from typing import Optional from urllib.parse import urlparse import aiohttp @@ -20,8 +19,8 @@ def __init__(self, *, reuse_client: bool = True) -> None: self.reuse_client = reuse_client - self._sync_client: Optional[requests.Session] = None - self._async_client: Optional[aiohttp.ClientSession] = None + self._sync_client: requests.Session | None = None + self._async_client: aiohttp.ClientSession | None = None def get_sync_client(self) -> requests.Session: if self._sync_client is None or not self.reuse_client: @@ -53,8 +52,8 @@ def get_response( url: str, *, stream: bool = False, - timeout: Optional[float] = None, - extra_headers: Optional[Mapping[str, str]] = None, + timeout: float | None = None, + extra_headers: Mapping[str, str] | None = None, allow_redirects: bool = True, ): self._validate_http_url(url) @@ -74,8 +73,8 @@ async def get_async_response( self, url: str, *, - timeout: Optional[float] = None, - extra_headers: Optional[Mapping[str, str]] = None, + timeout: float | None = None, + extra_headers: Mapping[str, str] | None = None, allow_redirects: bool = True, ): self._validate_http_url(url) @@ -91,7 +90,7 @@ async def get_async_response( ) def get_bytes( - self, url: str, *, timeout: Optional[float] = None, allow_redirects: bool = True + self, url: str, *, timeout: float | None = None, allow_redirects: bool = True ) -> bytes: with self.get_response( url, timeout=timeout, allow_redirects=allow_redirects @@ -104,7 +103,7 @@ async def async_get_bytes( self, url: str, *, - timeout: Optional[float] = None, + timeout: float | None = None, allow_redirects: bool = True, ) -> bytes: async with await self.get_async_response( @@ -114,7 +113,7 @@ async def async_get_bytes( return await r.read() - def get_text(self, url: str, *, timeout: Optional[float] = None) -> str: + def get_text(self, url: str, *, timeout: float | None = None) -> str: with self.get_response(url, timeout=timeout) as r: r.raise_for_status() @@ -124,14 +123,14 @@ async def async_get_text( self, url: str, *, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> str: async with await self.get_async_response(url, timeout=timeout) as r: r.raise_for_status() return await r.text() - def get_json(self, url: str, *, timeout: Optional[float] = None) -> str: + def get_json(self, url: str, *, timeout: float | None = None) -> str: with self.get_response(url, timeout=timeout) as r: r.raise_for_status() @@ -141,7 +140,7 @@ async def async_get_json( self, url: str, *, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> str: async with await self.get_async_response(url, timeout=timeout) as r: r.raise_for_status() @@ -153,7 +152,7 @@ def download_file( url: str, save_path: Path, *, - timeout: Optional[float] = None, + timeout: float | None = None, chunk_size: int = 128, ) -> Path: with self.get_response(url, timeout=timeout) as r: @@ -170,7 +169,7 @@ async def async_download_file( url: str, save_path: Path, *, - timeout: Optional[float] = None, + timeout: float | None = None, chunk_size: int = 128, ) -> Path: async with await self.get_async_response(url, timeout=timeout) as r: diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 97c6654385b3..5e3dbde393be 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ 
-11,18 +11,19 @@ import dataclasses import gc import os +from collections.abc import Callable from contextlib import contextmanager -from typing import Any, Callable, Optional, Union +from typing import Any import torch from vllm.logger import init_logger -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available logger = init_logger(__name__) -def find_loaded_library(lib_name) -> Optional[str]: +def find_loaded_library(lib_name) -> str | None: """ According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, the file `/proc/self/maps` contains the memory maps of the process, which includes the @@ -78,7 +79,7 @@ def find_loaded_library(lib_name) -> Optional[str]: class AllocationData: handle: HandleType tag: str - cpu_backup_tensor: Optional[torch.Tensor] = None + cpu_backup_tensor: torch.Tensor | None = None def create_and_map(allocation_handle: HandleType) -> None: @@ -197,7 +198,7 @@ def _python_free_callback(self, ptr: int) -> HandleType: ) return data.handle - def sleep(self, offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> None: + def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None: """ Put the allocator in sleep mode. All data in the memory allocation with the specified tag will be @@ -247,7 +248,7 @@ def sleep(self, offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> N gc.collect() torch.cuda.empty_cache() - def wake_up(self, tags: Optional[list[str]] = None) -> None: + def wake_up(self, tags: list[str] | None = None) -> None: """ Wake up the allocator from sleep mode. All data that is previously offloaded will be loaded back to GPU @@ -272,7 +273,7 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None: data.cpu_backup_tensor = None @contextmanager - def use_memory_pool(self, tag: Optional[str] = None): + def use_memory_pool(self, tag: str | None = None): """ A context manager to use the memory pool. 
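The `find_loaded_library` helper touched in the cumem.py hunk above locates an already-loaded shared library by scanning `/proc/self/maps`, as its docstring notes (Linux-only). A rough, self-contained sketch of that idea; illustrative only, the real implementation may differ in details:

```python
def find_loaded_library_sketch(lib_name: str) -> str | None:
    # /proc/self/maps lines: address perms offset dev inode [pathname]
    with open("/proc/self/maps") as f:
        for line in f:
            fields = line.split()
            # mappings backed by a file carry its absolute path as the last field
            if len(fields) >= 6 and lib_name in fields[-1].rsplit("/", 1)[-1]:
                return fields[-1]
    return None


# e.g. find_loaded_library_sketch("libcudart") might return
# "/usr/local/cuda/lib64/libcudart.so.12" (path shown for illustration)
```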
All memory allocation created inside the context will be allocated diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 46a735f22ed8..5ad99e4e1592 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union +from typing import Any import torch import torch.distributed @@ -30,13 +30,13 @@ def tensor_model_parallel_reduce_scatter( def tensor_model_parallel_gather( input_: torch.Tensor, dst: int = 0, dim: int = -1 -) -> Optional[torch.Tensor]: +) -> torch.Tensor | None: """Gather the input tensor across model parallel group.""" return get_tp_group().gather(input_, dst, dim) def broadcast_tensor_dict( - tensor_dict: Optional[dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0 + tensor_dict: dict[Any, torch.Tensor | Any] | None = None, src: int = 0 ): if not torch.distributed.is_initialized(): return tensor_dict diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index a22f43cd88d1..013ef3c1f5c3 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any import torch import torch.distributed as dist @@ -9,15 +9,17 @@ from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.utils import has_deep_ep, has_pplx from vllm.utils.flashinfer import has_flashinfer_all2all +from vllm.utils.import_utils import has_deep_ep, has_pplx from .base_device_communicator import All2AllManagerBase, Cache if has_flashinfer_all2all(): - from flashinfer.comm import Mapping - from flashinfer.comm.mnnvl import MnnvlConfig - from flashinfer.comm.trtllm_alltoall import MnnvlMoe + from flashinfer.comm import Mapping # type: ignore[import-not-found] + from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found] + from flashinfer.comm.trtllm_alltoall import ( + MnnvlMoe, # type: ignore[import-not-found] + ) logger = init_logger(__name__) @@ -65,6 +67,7 @@ def dispatch( ) -> tuple[torch.Tensor, torch.Tensor]: sp_size = self.tp_group.world_size if is_sequence_parallel else 1 dp_metadata = get_forward_context().dp_metadata + assert dp_metadata is not None cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) hidden_states = self.naive_multicast( @@ -81,6 +84,7 @@ def combine( ep_rank = self.rank if is_sequence_parallel else self.dp_rank dp_metadata = get_forward_context().dp_metadata + assert dp_metadata is not None sp_size = self.tp_group.world_size if is_sequence_parallel else 1 cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) @@ -113,7 +117,10 @@ def dispatch( """ Gather hidden_states and router_logits from all dp ranks. 
""" - sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + dp_metadata = get_forward_context().dp_metadata + assert dp_metadata is not None + sizes = dp_metadata.get_chunk_sizes_across_dp_rank() + assert sizes is not None dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] @@ -130,7 +137,10 @@ def combine( """ Reduce-scatter hidden_states across all dp ranks. """ - sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + dp_metadata = get_forward_context().dp_metadata + assert dp_metadata is not None + sizes = dp_metadata.get_chunk_sizes_across_dp_rank() + assert sizes is not None dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes) @@ -155,7 +165,7 @@ def __init__(self, cpu_group): if self.internode: # inter-node communication needs nvshmem, # intra-node communication uses p2p mapping directly - from pplx_kernels.nvshmem import ( + from pplx_kernels.nvshmem import ( # type: ignore[import-not-found] nvshmem_alloc_empty_unique_id, nvshmem_get_unique_id, nvshmem_init, @@ -182,7 +192,7 @@ def __init__(self, cpu_group): self.handle_cache = Cache() def get_handle(self, kwargs): - import pplx_kernels as pplx + import pplx_kernels as pplx # type: ignore[import-not-found] return self.handle_cache.get_or_create( kwargs, @@ -208,7 +218,9 @@ def destroy(self): handle.destroy() if self.internode: - from pplx_kernels.nvshmem import nvshmem_finalize + from pplx_kernels.nvshmem import ( + nvshmem_finalize, # type: ignore[import-not-found] + ) logger.debug("PPLX NVSHMEM finalize") nvshmem_finalize() @@ -265,7 +277,7 @@ def _make_all2all_kwargs(self) -> dict[Any, Any]: num_rdma_bytes = None num_qps_per_rank = None - if self.internode: + if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE: num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 num_qps_per_rank = self.num_sms // 2 else: @@ -288,7 +300,7 @@ def get_handle(self, kwargs): "args are computed in the Manager itself." ) - import deep_ep + import deep_ep # type: ignore[import-not-found] buffer_kwargs = self._make_all2all_kwargs() logger.debug("DeepEP all2all args %s", buffer_kwargs) @@ -298,7 +310,7 @@ def get_handle(self, kwargs): return handle def set_num_sms(self, num_sms: int): - import deep_ep + import deep_ep # type: ignore[import-not-found] # Right now the buffers are sized for only what the kernels were # created with. So we can only reduce the number of SMS used @@ -332,7 +344,7 @@ def _make_all2all_kwargs( num_global_experts: Number of experts in the model. num_local_experts: Number of experts in an EP rank. """ - import deep_ep + import deep_ep # type: ignore[import-not-found] # Defaults for internode and intranode are taken from DeepEP tests. num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 @@ -351,6 +363,8 @@ def _make_all2all_kwargs( num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_qps_per_rank, + allow_nvlink_for_low_latency_mode=envs.VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK, + allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL, ) def get_handle(self, kwargs): @@ -358,7 +372,7 @@ def get_handle(self, kwargs): The kwargs for DeepEPLLAll2AllManager is dictated by _make_all2all_kwargs. 
""" - import deep_ep + import deep_ep # type: ignore[import-not-found] buffer_kwargs = self._make_all2all_kwargs(**kwargs) logger.debug("DeepEP all2all args %s", buffer_kwargs) @@ -368,7 +382,7 @@ def get_handle(self, kwargs): return handle # DeepEP LL uses RDMA so no SMs are used for communication - def max_sms_used(self) -> Optional[int]: + def max_sms_used(self) -> int | None: return 0 @@ -377,6 +391,11 @@ class FlashInferAllToAllManager(All2AllManagerBase): All2All communication based on flashinfer kernels. """ + # This type lint could be removed after all of the work in + # https://github.com/vllm-project/vllm/issues/26533 done. + rank: int + world_size: int + def __init__(self, cpu_group): assert has_flashinfer_all2all(), ( "flashinfer all2all module not found. Please install/check flashinfer" diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index dabb48320be4..ff2d7436b270 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -10,7 +10,7 @@ import tempfile from collections.abc import Sequence from itertools import product -from typing import Any, Optional +from typing import Any import torch import torch.distributed as dist @@ -19,7 +19,11 @@ import vllm.envs as envs from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, update_environment_variables +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.utils.system_utils import update_environment_variables +from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) @@ -71,6 +75,9 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) is_symmetric_memory_enabled, ) + if vllm_is_batch_invariant(): + return False + if not is_symmetric_memory_enabled(): return False if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]: @@ -86,7 +93,7 @@ def producer( producer_queue, consumer_queue, result_queue, - cuda_visible_devices: Optional[str] = None, + cuda_visible_devices: str | None = None, ): if cuda_visible_devices is not None: update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) @@ -120,7 +127,7 @@ def consumer( producer_queue, consumer_queue, result_queue, - cuda_visible_devices: Optional[str] = None, + cuda_visible_devices: str | None = None, ): if cuda_visible_devices is not None: update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) @@ -253,7 +260,7 @@ def can_actually_p2p( # e.g. used by different vllm engines. The device id in the cache file is a # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number # of visible devices in the vllm engine. 
-_gpu_p2p_access_cache: Optional[dict[str, bool]] = None +_gpu_p2p_access_cache: dict[str, bool] | None = None def gpu_p2p_access_check(src: int, tgt: int) -> bool: diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index c32be0bec55c..9566dbac7f22 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading -from typing import Optional, Union from weakref import WeakValueDictionary import torch @@ -75,7 +74,7 @@ def dispatch( def set_num_sms(self, num_sms: int): pass - def max_sms_used(self) -> Optional[int]: + def max_sms_used(self) -> int | None: return None # None means it could use the whole GPU def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False): @@ -96,8 +95,8 @@ class DeviceCommunicatorBase: def __init__( self, cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, unique_name: str = "", ): self.device = device or torch.device("cpu") @@ -112,6 +111,7 @@ def __init__( self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) use_ep = False + all2all_backend = None from vllm.config import get_current_vllm_config config = get_current_vllm_config() @@ -120,10 +120,12 @@ def __init__( # where all data parallel ranks execute forward together), # we initialize the all2all manager used in expert parallel. use_ep = config.parallel_config.data_parallel_size > 1 + all2all_backend = config.parallel_config.all2all_backend self.is_ep_communicator = "ep" in unique_name self.use_all2all = self.is_ep_communicator and use_ep - self.all2all_manager: Optional[All2AllManagerBase] = None + self.all2all_backend = all2all_backend + self.all2all_manager: All2AllManagerBase | None = None def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) @@ -156,10 +158,10 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def all_gatherv( self, - input_: Union[torch.Tensor, list[torch.Tensor]], + input_: torch.Tensor | list[torch.Tensor], dim: int = 0, - sizes: Optional[list[int]] = None, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + sizes: list[int] | None = None, + ) -> torch.Tensor | list[torch.Tensor]: raise NotImplementedError def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: @@ -196,13 +198,13 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return output_tensor.movedim(0, dim).contiguous() def reduce_scatterv( - self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None ) -> torch.Tensor: raise NotImplementedError def gather( self, input_: torch.Tensor, dst: int = 0, dim: int = -1 - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ NOTE: We assume that the input tensor is on the same device across all the ranks. 
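A recurring theme in these communicator hunks is that the all2all backend is now read from `parallel_config.all2all_backend` (settable via `--all2all-backend`) instead of the `VLLM_ALL2ALL_BACKEND` environment variable: `DeviceCommunicatorBase` records it at construction and `CudaCommunicator` (next file) dispatches on it. The if/elif dispatch is condensed here into a registry purely for illustration, assuming the manager classes named in the diff are importable from `vllm.distributed.device_communicators.all2all`:

```python
from vllm.distributed.device_communicators.all2all import (
    AgRsAll2AllManager,
    DeepEPHTAll2AllManager,
    DeepEPLLAll2AllManager,
    FlashInferAllToAllManager,
    NaiveAll2AllManager,
    PPLXAll2AllManager,
)

_ALL2ALL_MANAGERS = {
    "naive": NaiveAll2AllManager,
    "allgather_reducescatter": AgRsAll2AllManager,
    "pplx": PPLXAll2AllManager,
    "deepep_high_throughput": DeepEPHTAll2AllManager,
    "deepep_low_latency": DeepEPLLAll2AllManager,
    "flashinfer_all2allv": FlashInferAllToAllManager,
}


def make_all2all_manager(backend: str, cpu_group):
    # unknown backends still fail loudly, as in the original if/elif chain
    if backend not in _ALL2ALL_MANAGERS:
        raise ValueError(f"Unknown all2all backend: {backend}")
    return _ALL2ALL_MANAGERS[backend](cpu_group)
```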
@@ -231,7 +233,7 @@ def gather( output_tensor = None return output_tensor - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: """Sends a tensor to the destination rank in a blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" if dst is None: @@ -239,7 +241,7 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: torch.distributed.send(tensor, self.ranks[dst], self.device_group) def recv( - self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + self, size: torch.Size, dtype: torch.dtype, src: int | None = None ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index c09b3ba9ceba..fdfb74d7a752 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Any, Optional, Union +from typing import Any import torch from torch.distributed import ProcessGroup @@ -18,8 +18,8 @@ class CpuCommunicator(DeviceCommunicatorBase): def __init__( self, cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, unique_name: str = "", ): super().__init__(cpu_group, device, device_group, unique_name) @@ -38,7 +38,7 @@ def all_reduce(self, input_): def gather( self, input_: torch.Tensor, dst: int = 0, dim: int = -1 - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ NOTE: We assume that the input tensor is on the same device across all the ranks. @@ -99,7 +99,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def send_tensor_dict( self, - tensor_dict: dict[str, Union[torch.Tensor, Any]], + tensor_dict: dict[str, torch.Tensor | Any], dst: int, ) -> None: return self.dist_module.send_tensor_dict(tensor_dict, dst) @@ -107,7 +107,7 @@ def send_tensor_dict( def recv_tensor_dict( self, src: int, - ) -> dict[str, Union[torch.Tensor, Any]]: + ) -> dict[str, torch.Tensor | Any]: return self.dist_module.recv_tensor_dict(src) @@ -140,16 +140,16 @@ def _init_cpu_shm(self) -> int: return handle def all_reduce( - self, input: torch.Tensor, group: Optional[ProcessGroup] = None + self, input: torch.Tensor, group: ProcessGroup | None = None ) -> None: torch.ops._C.shm_allreduce(self.handle, input) def gather( self, input: torch.Tensor, - gather_list: Optional[list[torch.Tensor]], + gather_list: list[torch.Tensor] | None, dst: int = -1, - group: Optional[ProcessGroup] = None, + group: ProcessGroup | None = None, ) -> None: # Note: different from the torch gather, here we use local dst rank. 
torch.ops._C.shm_gather( @@ -163,13 +163,13 @@ def all_gather_into_tensor( self, output: torch.Tensor, input: torch.Tensor, - group: Optional[ProcessGroup] = None, + group: ProcessGroup | None = None, ) -> None: torch.ops._C.shm_all_gather(self.handle, input, output) def send_tensor_dict( self, - tensor_dict: dict[str, Union[torch.Tensor, Any]], + tensor_dict: dict[str, torch.Tensor | Any], dst: int, ) -> None: key_list = list(tensor_dict.keys()) @@ -191,7 +191,7 @@ def send_tensor_dict( def recv_tensor_dict( self, src: int, - ) -> dict[str, Union[torch.Tensor, Any]]: + ) -> dict[str, torch.Tensor | Any]: tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src) value_list: list[torch.Tensor] = tensor_list[:-1] diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 45096dffb5b6..2e878eef908a 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch from torch.distributed import ProcessGroup @@ -26,8 +25,8 @@ class CudaCommunicator(DeviceCommunicatorBase): def __init__( self, cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, unique_name: str = "", ): super().__init__(cpu_group, device, device_group, unique_name) @@ -54,7 +53,7 @@ def __init__( ) from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator - self.pynccl_comm: Optional[PyNcclCommunicator] = None + self.pynccl_comm: PyNcclCommunicator | None = None if self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, @@ -63,9 +62,9 @@ def __init__( if is_symmetric_memory_enabled(): register_nccl_symmetric_ops(self.pynccl_comm) - self.ca_comm: Optional[CustomAllreduce] = None - self.qr_comm: Optional[QuickAllReduce] = None - self.symm_mem_comm: Optional[SymmMemCommunicator] = None + self.ca_comm: CustomAllreduce | None = None + self.qr_comm: QuickAllReduce | None = None + self.symm_mem_comm: SymmMemCommunicator | None = None if use_torch_symm_mem and current_platform.is_cuda(): self.symm_mem_comm = SymmMemCommunicator( group=self.cpu_group, @@ -91,39 +90,38 @@ def __init__( self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) if self.use_all2all: - all2all_backend = envs.VLLM_ALL2ALL_BACKEND - if all2all_backend == "naive": + if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) - logger.info("Using naive all2all manager.") - elif all2all_backend == "allgather_reducescatter": + elif self.all2all_backend == "allgather_reducescatter": from .all2all import AgRsAll2AllManager self.all2all_manager = AgRsAll2AllManager(self.cpu_group) - logger.info("Using AllGather-ReduceScatter all2all manager.") - elif all2all_backend == "pplx": + elif self.all2all_backend == "pplx": from .all2all import PPLXAll2AllManager self.all2all_manager = PPLXAll2AllManager(self.cpu_group) - logger.info("Using PPLX all2all manager.") - elif all2all_backend == "deepep_high_throughput": + elif self.all2all_backend == "deepep_high_throughput": from .all2all import DeepEPHTAll2AllManager self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) - 
logger.info("Using DeepEP High-Throughput all2all manager.") - elif all2all_backend == "deepep_low_latency": + elif self.all2all_backend == "deepep_low_latency": from .all2all import DeepEPLLAll2AllManager self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) - logger.info("Using DeepEP Low-Latency all2all manager.") - elif all2all_backend == "flashinfer_all2allv": + elif self.all2all_backend == "flashinfer_all2allv": from .all2all import FlashInferAllToAllManager self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) - logger.info("Using Flashinfer all2allv manager.") else: - raise ValueError(f"Unknown all2all backend: {all2all_backend}") + raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") + + logger.info_once( + "Using %s all2all manager.", + self.all2all_manager.__class__.__name__, + scope="global", + ) def all_reduce(self, input_): # since currently we perform copy input -> symm_input -> out-of-place AR @@ -201,7 +199,7 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): return output.movedim(0, dim).contiguous() def reduce_scatterv( - self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None ): world_size = self.world_size pynccl_comm = self.pynccl_comm @@ -235,7 +233,7 @@ def reduce_scatterv( # Reshape before returning return output.movedim(0, dim).contiguous() - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: """Sends a tensor to the destination rank in a blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" if dst is None: @@ -248,7 +246,7 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: torch.distributed.send(tensor, self.ranks[dst], self.device_group) def recv( - self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + self, size: torch.Size, dtype: torch.dtype, src: int | None = None ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" @@ -274,9 +272,9 @@ def destroy(self): def all_gatherv( self, - input_: Union[torch.Tensor, list[torch.Tensor]], + input_: torch.Tensor | list[torch.Tensor], dim: int = 0, - sizes: Optional[list[int]] = None, + sizes: list[int] | None = None, ): if dim != 0: raise NotImplementedError("only dim 0 all-gatherv is supported") @@ -289,7 +287,7 @@ def all_gatherv( if sizes is not None and all(s == sizes[0] for s in sizes): sizes = None - def _all_gather_single(input_: torch.Tensor, sizes: Optional[list[int]] = None): + def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None): input_size = input_.size() if sizes is not None: assert len(sizes) == world_size diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index a77d2666e2ce..07ab2f712409 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -7,7 +7,7 @@ import ctypes from dataclasses import dataclass -from typing import Any, Optional +from typing import Any # this line makes it possible to directly load `libcudart.so` using `ctypes` import torch # noqa @@ -36,7 +36,7 @@ class Function: argtypes: list[Any] -def find_loaded_library(lib_name) -> Optional[str]: +def find_loaded_library(lib_name) -> str | None: """ According to according to 
https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, the file `/proc/self/maps` contains the memory maps of the process, which includes the @@ -113,7 +113,7 @@ class CudaRTLibrary: # to the corresponding dictionary path_to_dict_mapping: dict[str, dict[str, Any]] = {} - def __init__(self, so_file: Optional[str] = None): + def __init__(self, so_file: str | None = None): if so_file is None: so_file = find_loaded_library("libcudart") if so_file is None: diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index fd5c5dfd9da0..02591805a796 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import Optional, Union +from typing import cast import torch import torch.distributed as dist @@ -17,7 +17,7 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless try: ops.meta_size() @@ -34,7 +34,7 @@ def _can_p2p(rank: int, world_size: int) -> bool: if i == rank: continue if envs.VLLM_SKIP_P2P_CHECK: - logger.info("Skipping P2P check and trusting the driver's P2P report.") + logger.debug("Skipping P2P check and trusting the driver's P2P report.") return torch.cuda.can_device_access_peer(rank, i) if not gpu_p2p_access_check(rank, i): return False @@ -55,7 +55,7 @@ class CustomAllreduce: def __init__( self, group: ProcessGroup, - device: Union[int, str, torch.device], + device: int | str | torch.device, max_size=8192 * 1024, symm_mem_enabled=False, ) -> None: @@ -119,15 +119,18 @@ def __init__( # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - device_capability = current_platform.get_device_capability().as_version_str() + device_capability = current_platform.get_device_capability() if ( current_platform.is_cuda() and symm_mem_enabled - and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES + and device_capability is not None ): - max_size = min( - CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size - ) + device_capability_str = device_capability.as_version_str() + if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES: + max_size = min( + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size], + max_size, + ) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) @@ -214,6 +217,7 @@ def register_graph_buffers(self): # We cannot directly use `dist.all_gather_object` here # because it is incompatible with `gloo` backend under inference mode. # see https://github.com/pytorch/pytorch/issues/126032 for details. + all_data: list[list[list[int] | None]] all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] all_data[self.rank] = [handle, offset] ranks = sorted(dist.get_process_group_ranks(group=self.group)) @@ -222,8 +226,8 @@ def register_graph_buffers(self): all_data[i], src=rank, group=self.group, device="cpu" ) # Unpack list of tuples to tuple of lists. 
- handles = [d[0] for d in all_data] # type: ignore - offsets = [d[1] for d in all_data] # type: ignore + handles = cast(list[list[int]], [d[0] for d in all_data]) + offsets = cast(list[list[int]], [d[1] for d in all_data]) ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): @@ -260,7 +264,7 @@ def all_reduce( ) return out - def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + def custom_all_reduce(self, input: torch.Tensor) -> torch.Tensor | None: """The main allreduce API that provides support for cuda graph.""" # When custom allreduce is disabled, this will be None. if self.disabled or not self.should_custom_ar(input): @@ -292,8 +296,8 @@ def __del__(self): @staticmethod def create_shared_buffer( size_in_bytes: int, - group: Optional[ProcessGroup] = None, - uncached: Optional[bool] = False, + group: ProcessGroup | None = None, + uncached: bool | None = False, ) -> list[int]: pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) @@ -313,8 +317,8 @@ def create_shared_buffer( @staticmethod def free_shared_buffer( pointers: list[int], - group: Optional[ProcessGroup] = None, - rank: Optional[int] = None, + group: ProcessGroup | None = None, + rank: int | None = None, ) -> None: if rank is None: rank = dist.get_rank(group=group) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 59fa3f9c449b..2fc35e80f591 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union # ===================== import region ===================== import torch @@ -20,7 +19,7 @@ ) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream logger = init_logger(__name__) @@ -31,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm): from vllm.distributed.device_communicators.pynccl_allocator import ( nccl_symm_mem_context, ) - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op global _NCCL_SYMM_OPS_REGISTERED if _NCCL_SYMM_OPS_REGISTERED: @@ -59,9 +58,9 @@ def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Ten class PyNcclCommunicator: def __init__( self, - group: Union[ProcessGroup, StatelessProcessGroup], - device: Union[int, str, torch.device], - library_path: Optional[str] = None, + group: ProcessGroup | StatelessProcessGroup, + device: int | str | torch.device, + library_path: str | None = None, ): """ Args: @@ -106,11 +105,12 @@ def __init__( self.disabled = False self.nccl_version = self.nccl.ncclGetRawVersion() - logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) - if self.rank == 0: # get the unique id from NCCL self.unique_id = self.nccl.ncclGetUniqueId() + logger.info_once( + "vLLM is using nccl==%s", self.nccl.ncclGetVersion(), scope="local" + ) else: # construct an empty unique id self.unique_id = ncclUniqueId() diff --git a/vllm/distributed/device_communicators/pynccl_allocator.py b/vllm/distributed/device_communicators/pynccl_allocator.py index 3fe4fd744d77..401b80046f60 100644 --- a/vllm/distributed/device_communicators/pynccl_allocator.py +++ b/vllm/distributed/device_communicators/pynccl_allocator.py @@ -3,7 +3,7 
@@ import atexit import contextlib import tempfile -from typing import Any, Optional +from typing import Any import torch from packaging import version @@ -14,7 +14,7 @@ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import find_nccl_include_paths +from vllm.utils.nccl import find_nccl_include_paths logger = init_logger(__name__) @@ -141,7 +141,7 @@ def __init__( or version.parse(torch.__version__) < version.parse("2.8.0.a0") ) if self.disabled: - self.pynccl_comm: Optional[PyNcclCommunicator] = None + self.pynccl_comm: PyNcclCommunicator | None = None self._mem_pool_ctx: contextlib.AbstractContextManager[Any] = ( contextlib.nullcontext() ) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index e4d7b0f8fb85..b2433d58dc1f 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -25,7 +25,7 @@ import ctypes import platform from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import torch from torch.distributed import ReduceOp @@ -33,7 +33,7 @@ from vllm import envs from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import find_nccl_library +from vllm.utils.nccl import find_nccl_library logger = init_logger(__name__) @@ -305,7 +305,7 @@ class NCCLLibrary: # to the corresponding dictionary path_to_dict_mapping: dict[str, dict[str, Any]] = {} - def __init__(self, so_file: Optional[str] = None): + def __init__(self, so_file: str | None = None): so_file = so_file or find_nccl_library() try: diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 16b6b6c28ea3..9c7765883cfd 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum -from typing import Union import torch import torch.distributed as dist @@ -14,7 +13,7 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) @@ -58,9 +57,7 @@ class QuickAllReduce: (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], } - def __init__( - self, group: ProcessGroup, device: Union[int, str, torch.device] - ) -> None: + def __init__(self, group: ProcessGroup, device: int | str | torch.device) -> None: """ Custom allreduce provides non-destructive acceleration and is available for CUDA and ROCm MI300 series. 
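Most of the remaining churn across these distributed files is the same mechanical typing cleanup: `Optional[X]` and `Union[A, B]` become PEP 604 unions, and `Callable` moves from `typing` to `collections.abc`. A small before/after illustration (the signatures are invented for the example, not taken from vLLM):

```python
# before
from typing import Callable, Optional, Union

def fetch(url: str, timeout: Optional[float] = None,
          on_done: Optional[Callable[[bytes], None]] = None) -> Union[bytes, None]: ...

# after: evaluates at runtime on Python 3.10+
# (older interpreters need `from __future__ import annotations`)
from collections.abc import Callable

def fetch(url: str, timeout: float | None = None,
          on_done: Callable[[bytes], None] | None = None) -> bytes | None: ...
```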
diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index da79afc7ac14..d9517f51acad 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import uuid -from typing import Any, Optional +from typing import Any import ray import torch @@ -14,7 +14,7 @@ ) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream logger = init_logger(__name__) @@ -27,15 +27,15 @@ class RayPPCommunicator(Communicator): This class is not thread-safe. """ - _comm: Optional[DeviceCommunicatorBase] + _comm: DeviceCommunicatorBase | None def __init__( self, world_size: int, comm_id: Any, - rank: Optional[int], + rank: int | None, actor_handles: list["ray.actor.ActorHandle"], - cuda_stream: Optional[torch.cuda.Stream], + cuda_stream: torch.cuda.Stream | None, use_communication_streams: bool = False, ): """ @@ -56,7 +56,7 @@ def __init__( This is not supported. """ self._world_size = world_size - self._rank: Optional[int] = None + self._rank: int | None = None self._actor_handles = actor_handles if use_communication_streams: raise NotImplementedError("use_communication_streams is not supported") @@ -99,7 +99,7 @@ def _build_actor_rank_mapping(self): # Ray actor IDs are 32-character hex strings (128 bits) ACTOR_ID_LEN = 32 - actor_id_bytes = actor_id_str.encode("utf-8") + actor_id_bytes = bytearray(actor_id_str.encode("utf-8")) assert len(actor_id_bytes) == ACTOR_ID_LEN, ( f"Unexpected actor ID length: {len(actor_id_bytes)}" ) @@ -143,7 +143,7 @@ def get_rank(self, actor: ray.actor.ActorHandle) -> int: else: raise ValueError(f"Actor {actor} not found in communicator group") - def get_self_rank(self) -> Optional[int]: + def get_self_rank(self) -> int | None: """ Return this actor's rank. 
""" diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 4cec60102728..f92b3d34af0f 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import functools import pickle import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory +from pickle import PickleBuffer from threading import Event -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any from unittest.mock import patch import torch @@ -26,15 +27,25 @@ import vllm.envs as envs from vllm.distributed.utils import StatelessProcessGroup, sched_yield from vllm.logger import init_logger -from vllm.utils import ( +from vllm.utils.network_utils import ( get_ip, get_open_port, get_open_zmq_ipc_path, is_valid_ipv6_address, ) +if TYPE_CHECKING: + from _typeshed import SizedBuffer + VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL +from_bytes_big = functools.partial(int.from_bytes, byteorder="big") + + +def to_bytes_big(value: int, size: int) -> bytes: + return value.to_bytes(size, byteorder="big") + + logger = init_logger(__name__) @@ -80,7 +91,7 @@ def __init__( n_reader: int, max_chunk_bytes: int, max_chunks: int, - name: Optional[str] = None, + name: str | None = None, ): """ A shared memory ring buffer implementation for broadcast communication. @@ -213,9 +224,9 @@ def get_metadata(self, current_idx: int): class Handle: local_reader_ranks: list[int] = field(default_factory=list) - buffer_handle: Optional[tuple[int, int, int, str]] = None - local_subscribe_addr: Optional[str] = None - remote_subscribe_addr: Optional[str] = None + buffer_handle: tuple[int, int, int, str] | None = None + local_subscribe_addr: str | None = None + remote_subscribe_addr: str | None = None remote_addr_ipv6: bool = False @@ -224,10 +235,12 @@ def __init__( self, n_reader, # number of all readers n_local_reader, # number of local readers through shared memory - local_reader_ranks: Optional[list[int]] = None, - max_chunk_bytes: int = 1024 * 1024 * 10, + local_reader_ranks: list[int] | None = None, + # Default of 24MiB chosen to be large enough to accommodate grammar + # bitmask tensors for large batches (1024 requests). 
+ max_chunk_bytes: int = 1024 * 1024 * 24, max_chunks: int = 10, - connect_ip: Optional[str] = None, + connect_ip: str | None = None, ): if local_reader_ranks is None: local_reader_ranks = list(range(n_local_reader)) @@ -299,7 +312,7 @@ def __init__( remote_addr_ipv6=remote_addr_ipv6, ) - logger.info("vLLM message queue communication handle: %s", self.handle) + logger.debug("vLLM message queue communication handle: %s", self.handle) def export_handle(self) -> Handle: return self.handle @@ -384,7 +397,7 @@ def wait_until_ready(self): assert recv == b"READY" @contextmanager - def acquire_write(self, timeout: Optional[float] = None): + def acquire_write(self, timeout: float | None = None): assert self._is_writer, "Only writers can acquire write" start_time = time.monotonic() n_warning = 1 @@ -444,8 +457,8 @@ def acquire_write(self, timeout: Optional[float] = None): @contextmanager def acquire_read( self, - timeout: Optional[float] = None, - cancel: Optional[Event] = None, + timeout: float | None = None, + cancel: Event | None = None, indefinite: bool = False, ): assert self._is_local_reader, "Only readers can acquire read" @@ -502,26 +515,53 @@ def acquire_read( self._read_spin_timer.record_activity() break - def enqueue(self, obj, timeout: Optional[float] = None): + def enqueue(self, obj, timeout: float | None = None): """Write to message queue with optional timeout (in seconds)""" assert self._is_writer, "Only writers can enqueue" - serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + all_buffers: list[SizedBuffer] = [b""] + total_bytes = 6 # 2 bytes for oob buffer count, 4 for main buffer size + + def oob_callback(buf: PickleBuffer) -> bool: + raw_buf = buf.raw() + if len(raw_buf) < 1024 * 1024: + # In-line buffers smaller than 1MiB. + return True + all_buffers.append(raw_buf) + nonlocal total_bytes + total_bytes += len(raw_buf) + 4 + return False + + all_buffers[0] = pickle.dumps( + obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback + ) if self.n_local_reader > 0: - if len(serialized_obj) >= self.buffer.max_chunk_bytes: + if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes: with self.acquire_write(timeout) as buf: buf[0] = 1 # overflow - self.local_socket.send(serialized_obj) + self.local_socket.send_multipart(all_buffers, copy=False) else: + # Byte 0: 0 + # Bytes 1-2: Count of buffers + # Then each buffer follows, preceded by 4 bytes containing its length: + # [4 byte int L][L bytes of buffer content] ... with self.acquire_write(timeout) as buf: buf[0] = 0 # not overflow - buf[1 : len(serialized_obj) + 1] = serialized_obj + offset = 3 + buf[1:offset] = to_bytes_big(len(all_buffers), 2) # oob buf count + for buffer in all_buffers: + buf_len = len(buffer) + # prepend each buffer with 4 bytes containing its size. 
+ buf_offset = offset + 4 + buf[offset:buf_offset] = to_bytes_big(buf_len, 4) + buf[buf_offset : (offset := buf_offset + buf_len)] = buffer + if self.n_remote_reader > 0: - self.remote_socket.send(serialized_obj) + self.remote_socket.send_multipart(all_buffers, copy=False) def dequeue( self, - timeout: Optional[float] = None, - cancel: Optional[Event] = None, + timeout: float | None = None, + cancel: Event | None = None, indefinite: bool = False, ): """Read from message queue with optional timeout (in seconds)""" @@ -529,10 +569,15 @@ def dequeue( with self.acquire_read(timeout, cancel, indefinite) as buf: overflow = buf[0] == 1 if not overflow: - # no need to know the size of serialized object - # pickle format contains the size information internally - # see https://docs.python.org/3/library/pickle.html - obj = pickle.loads(buf[1:]) + offset = 3 + buf_count = from_bytes_big(buf[1:offset]) + all_buffers = [] + for i in range(buf_count): + buf_offset = offset + 4 + buf_len = from_bytes_big(buf[offset:buf_offset]) + offset = buf_offset + buf_len + all_buffers.append(buf[buf_offset:offset]) + obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:]) if overflow: obj = MessageQueue.recv(self.local_socket, timeout) elif self._is_remote_reader: @@ -542,23 +587,22 @@ def dequeue( return obj @staticmethod - def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any: + def recv(socket: zmq.Socket, timeout: float | None) -> Any: timeout_ms = None if timeout is None else int(timeout * 1000) if not socket.poll(timeout=timeout_ms): raise TimeoutError - recv = socket.recv(copy=False) - return pickle.loads(recv.buffer) + recv, *recv_oob = socket.recv_multipart(copy=False) + return pickle.loads(recv, buffers=recv_oob) def broadcast_object(self, obj=None): if self._is_writer: self.enqueue(obj) return obj - else: - return self.dequeue() + return self.dequeue() @staticmethod def create_from_process_group( - pg: Union[ProcessGroup, StatelessProcessGroup], + pg: ProcessGroup | StatelessProcessGroup, max_chunk_bytes, max_chunks, writer_rank=0, diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py index a5486c30edf2..080bc03e3913 100644 --- a/vllm/distributed/device_communicators/shm_object_storage.py +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -3,13 +3,13 @@ import pickle from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Callable, Iterable from contextlib import contextmanager from dataclasses import dataclass from itertools import chain from multiprocessing import shared_memory from multiprocessing.synchronize import Lock as LockType -from typing import Any, Callable, Optional, Union +from typing import Any from unittest.mock import patch import torch @@ -109,7 +109,7 @@ class SingleWriterShmRingBuffer: def __init__( self, data_buffer_size: int, - name: Optional[str] = None, + name: str | None = None, create: bool = False, ): self.data_buffer_size = data_buffer_size @@ -252,7 +252,7 @@ def access_buf(self, address: int): def free_buf( self, is_free_fn: Callable[[int, memoryview], bool], - nbytes: Optional[int] = None, + nbytes: int | None = None, ) -> Iterable[int]: """ Free a buffer of the given size. 
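The shm_broadcast enqueue/dequeue rework above switches to pickle protocol-5 out-of-band buffers: buffers of 1 MiB and larger skip the main pickle stream and travel as separate frames (length-prefixed inside the ring buffer, or via `send_multipart` over ZMQ). A self-contained sketch of the round trip with the same 1 MiB cut-off; names are illustrative:

```python
import pickle

oob_buffers: list[pickle.PickleBuffer] = []


def oob_callback(buf: pickle.PickleBuffer) -> bool:
    # keep small buffers in-line, ship large ones out-of-band
    if len(buf.raw()) < 1024 * 1024:
        return True           # True -> serialized inside the main pickle stream
    oob_buffers.append(buf)   # collected for (near) zero-copy transport
    return False              # False -> excluded from the main stream


blob = pickle.PickleBuffer(bytearray(4 * 1024 * 1024))  # stand-in for a big tensor
main = pickle.dumps({"meta": "header", "data": blob},
                    protocol=pickle.HIGHEST_PROTOCOL,
                    buffer_callback=oob_callback)

# the main frame stays tiny; the 4 MiB payload is a separate buffer
restored = pickle.loads(main, buffers=oob_buffers)
assert len(main) < 1024 and len(restored["data"]) == 4 * 1024 * 1024
```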
This is a no-op in shared memory, @@ -340,9 +340,7 @@ def __init__(self): self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem) self._mm_kwargs_item_cls = MultiModalKwargsItem - def serialize( - self, value: Any - ) -> tuple[Union[bytes, list[bytes]], int, bytes, int]: + def serialize(self, value: Any) -> tuple[bytes | list[bytes], int, bytes, int]: len_arr = None if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)): type_name = type(value).__name__ @@ -396,7 +394,7 @@ class ShmObjectStorageHandle: n_readers: int ring_buffer_handle: tuple[int, str] serde_class: type[ObjectSerde] - reader_lock: Optional[LockType] + reader_lock: LockType | None class SingleWriterShmObjectStorage: @@ -444,7 +442,7 @@ def __init__( n_readers: int, ring_buffer: SingleWriterShmRingBuffer, serde_class: type[ObjectSerde] = MsgpackSerde, - reader_lock: Optional[LockType] = None, + reader_lock: LockType | None = None, ): """ Initialize the object storage. @@ -492,7 +490,7 @@ def clear(self) -> None: def copy_to_buffer( self, - data: Union[bytes, list[bytes]], + data: bytes | list[bytes], data_bytes: int, metadata: bytes, md_bytes: int, diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index 88451f9552c1..74d6fb40c83b 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch import torch.distributed as dist @@ -10,6 +9,9 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES, ) from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.platforms import current_platform try: @@ -31,10 +33,10 @@ class SymmMemCommunicator: def __init__( self, group: ProcessGroup, - device: Union[int, str, torch.device], + device: int | str | torch.device, # add options for testing - force_multimem: Optional[bool] = None, - max_size_override: Optional[int] = None, + force_multimem: bool | None = None, + max_size_override: int | None = None, ): self.disabled = True @@ -53,9 +55,14 @@ def __init__( self.device = device self.group = group self.world_size = dist.get_world_size(self.group) - self.device_capability = ( - current_platform.get_device_capability().as_version_str() - ) + capability = current_platform.get_device_capability() + if capability is None: + logger.warning( + "SymmMemCommunicator: device capability is unknown, " + "communicator is not available." 
+ ) + return + self.device_capability = capability.as_version_str() if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: logger.warning( "SymmMemCommunicator: Device capability %s not supported, " @@ -96,6 +103,8 @@ def __init__( return self.force_multimem = force_multimem self.disabled = False + if vllm_is_batch_invariant(): + self.disabled = True def should_use_symm_mem(self, inp: torch.Tensor): if self.disabled: @@ -108,8 +117,8 @@ def should_use_symm_mem(self, inp: torch.Tensor): return inp_size < self.max_size def all_reduce( - self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None - ) -> Optional[torch.Tensor]: + self, inp: torch.Tensor, *, out: torch.Tensor | None = None + ) -> torch.Tensor | None: if not self.should_use_symm_mem(inp): return None if out is None: diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index b2faea512791..a7724a86cc6a 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional import torch from torch.distributed import ProcessGroup @@ -32,15 +31,15 @@ ) if USE_RAY: - from vllm.executor import ray_utils + from vllm.v1.executor import ray_utils class TpuCommunicator(DeviceCommunicatorBase): def __init__( self, cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, unique_name: str = "", ): super().__init__(cpu_group, device, device_group, unique_name) diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index 33d5b2cf1d87..ad61fdfb8ea5 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch import torch.distributed as dist from torch.distributed import ProcessGroup -import vllm.envs as envs from vllm.logger import init_logger from .base_device_communicator import DeviceCommunicatorBase @@ -19,21 +17,20 @@ class XpuCommunicator(DeviceCommunicatorBase): def __init__( self, cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, unique_name: str = "", ): super().__init__(cpu_group, device, device_group, unique_name) if self.use_all2all: - all2all_backend = envs.VLLM_ALL2ALL_BACKEND - if all2all_backend != "naive": + if self.all2all_backend != "naive": logger.warning( - "`%s` all2all manager is not supported on XPU." + "`%s` all2all manager is not supported on XPU. 
" "Falling back to `naive` all2all manager for XPU.", - all2all_backend, + self.all2all_backend, ) - all2all_backend = "naive" - if all2all_backend == "naive": + self.all2all_backend = "naive" + if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager self.all2all_manager = NaiveAll2AllManager(self.cpu_group) @@ -45,7 +42,7 @@ def all_reduce(self, input_) -> torch.Tensor: def gather( self, input_: torch.Tensor, dst: int = 0, dim: int = -1 - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: assert -input_.dim() <= dim < input_.dim(), ( f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" ) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 663f04027046..17716e8a07ac 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -29,7 +29,6 @@ import time from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional, Union import torch from torch.distributed import ProcessGroup, all_reduce @@ -186,9 +185,9 @@ def build( model: MixtureOfExperts, device: torch.device, parallel_config: ParallelConfig, - global_expert_load: Optional[torch.Tensor] = None, - old_global_expert_indices: Optional[torch.Tensor] = None, - rank_mapping: Optional[dict[int, int]] = None, + global_expert_load: torch.Tensor | None = None, + old_global_expert_indices: torch.Tensor | None = None, + rank_mapping: dict[int, int] | None = None, ) -> "EplbState": """ Build the initial EPLB state. @@ -439,9 +438,9 @@ def rearrange( model: MixtureOfExperts, is_profile: bool = False, execute_shuffle: bool = True, - global_expert_load: Optional[torch.Tensor] = None, - rank_mapping: Optional[dict[int, int]] = None, - ) -> Optional[torch.Tensor]: + global_expert_load: torch.Tensor | None = None, + rank_mapping: dict[int, int] | None = None, + ) -> torch.Tensor | None: """ Rearrange the experts according to the current load. """ @@ -611,7 +610,7 @@ def recv_state() -> tuple[torch.Tensor, torch.Tensor]: def _node_count_with_rank_mapping( - pg: Union[ProcessGroup, StatelessProcessGroup], + pg: ProcessGroup | StatelessProcessGroup, rank_mapping: dict[int, int], ) -> int: if isinstance(pg, ProcessGroup): diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 344fae457c9b..f8ec3e956401 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -8,7 +8,6 @@ from collections.abc import Iterable, MutableSequence, Sequence from functools import partial -from typing import Optional import torch from torch.distributed import ( @@ -253,7 +252,7 @@ def rearrange_expert_weights_inplace( expert_weights: Sequence[Iterable[torch.Tensor]], ep_group: ProcessGroup, is_profile: bool = False, - rank_mapping: Optional[dict[int, int]] = None, + rank_mapping: dict[int, int] | None = None, ) -> None: """ Rearranges the expert weights in place according to the new expert indices. 
diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index d93ae63e0eb4..7b5cb94cf13e 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -6,10 +6,11 @@ import time from abc import ABC, abstractmethod from collections import deque +from collections.abc import Callable from dataclasses import asdict from itertools import count from queue import Queue -from typing import Any, Callable, Optional, Union +from typing import Any import msgspec import zmq @@ -29,7 +30,7 @@ class EventBatch( ): ts: float events: list[Any] - data_parallel_rank: Optional[int] = None + data_parallel_rank: int | None = None class KVCacheEvent( @@ -47,16 +48,16 @@ class KVCacheEvent( class BlockStored(KVCacheEvent): block_hashes: list[ExternalBlockHash] - parent_block_hash: Optional[ExternalBlockHash] + parent_block_hash: ExternalBlockHash | None token_ids: list[int] block_size: int - lora_id: Optional[int] - medium: Optional[str] + lora_id: int | None + medium: str | None class BlockRemoved(KVCacheEvent): block_hashes: list[ExternalBlockHash] - medium: Optional[str] + medium: str | None class AllBlocksCleared(KVCacheEvent): @@ -64,7 +65,7 @@ class AllBlocksCleared(KVCacheEvent): class KVEventBatch(EventBatch): - events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]] + events: list[BlockStored | BlockRemoved | AllBlocksCleared] class EventPublisher(ABC): @@ -116,7 +117,7 @@ class ZmqEventPublisher(EventPublisher): Parameters ---------- endpoint: - PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to + PUB address. Use `tcp://*:5557` to bind or `tcp://host:5557` to connect. replay_endpoint: Optional ROUTER address for replay requests. When given, subscribers can @@ -139,7 +140,7 @@ def __init__( self, data_parallel_rank: int, endpoint: str = "tcp://*:5557", - replay_endpoint: Optional[str] = None, + replay_endpoint: str | None = None, buffer_steps: int = 10_000, hwm: int = 100_000, max_queue_size: int = 100_000, @@ -147,13 +148,13 @@ def __init__( ) -> None: # Storage super().__init__(data_parallel_rank) - self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size) + self._event_queue = Queue[EventBatch | None](maxsize=max_queue_size) self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps) # ZMQ sockets self._ctx = zmq.Context.instance() - self._pub: Optional[zmq.Socket] = None - self._replay: Optional[zmq.Socket] = None + self._pub: zmq.Socket | None = None + self._replay: zmq.Socket | None = None self._dp_rank = data_parallel_rank self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank) @@ -303,8 +304,8 @@ def _service_replay(self) -> None: @staticmethod def offset_endpoint_port( - endpoint: Optional[str], data_parallel_rank: int - ) -> Optional[str]: + endpoint: str | None, data_parallel_rank: int + ) -> str | None: """Helper function to offset the port in an endpoint by the data parallel rank. 
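The offset_endpoint_port docstring above describes shifting the endpoint's port by the data-parallel rank so that each DP rank publishes on its own socket. Below is a hedged sketch of that behaviour only; the real helper in kv_events.py may handle more cases (for example non-TCP endpoints) differently.

def offset_port_sketch(endpoint: str | None, data_parallel_rank: int) -> str | None:
    # Illustrative only; mirrors the documented intent, not the exact code.
    if endpoint is None:
        return None
    prefix, _, port = endpoint.rpartition(":")
    if not port.isdigit():
        return endpoint  # e.g. endpoints without a numeric port
    return f"{prefix}:{int(port) + data_parallel_rank}"

# offset_port_sketch("tcp://*:5557", 2) -> "tcp://*:5559"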
@@ -349,15 +350,19 @@ def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> N @classmethod def create( - cls, config: Optional[KVEventsConfig], data_parallel_rank: int = 0 + cls, config: KVEventsConfig | None, data_parallel_rank: int = 0 ) -> EventPublisher: """Create publisher from a config mapping.""" - if not config: + if ( + config is None + or not config.enable_kv_cache_events + or config.publisher == "null" + ): return NullEventPublisher() config_dict = asdict(config) - kind = config_dict.pop("publisher", "null") + kind = config_dict.pop("publisher") config_dict.pop("enable_kv_cache_events") try: constructor = cls._registry[kind] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 395a4e20e0ba..c64996f13cd5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -2,18 +2,22 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, cast import vllm.envs as envs +from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import ( KVConnectorBase, KVConnectorBaseType, ) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorRole, + supports_hma, +) from vllm.logger import init_logger if TYPE_CHECKING: - from vllm.config import VllmConfig from vllm.config.kv_transfer import KVTransferConfig logger = init_logger(__name__) @@ -37,7 +41,7 @@ def loader() -> type[KVConnectorBase]: @classmethod def create_connector( cls, - config: "VllmConfig", + config: VllmConfig, role: KVConnectorRole, ) -> KVConnectorBase: if not envs.VLLM_USE_V1: @@ -47,7 +51,18 @@ def create_connector( ) kv_transfer_config = config.kv_transfer_config + if kv_transfer_config is None: + raise ValueError("kv_transfer_config must be set to create a connector") connector_cls = cls.get_connector_class(kv_transfer_config) + + # check if the connector supports HMA + hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager + if hma_enabled and not supports_hma(connector_cls): + raise ValueError( + f"Connector {connector_cls.__name__} does not support HMA but " + f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`." + ) + logger.info( "Creating v1 connector with name: %s and engine_id: %s", connector_cls.__name__, @@ -63,12 +78,32 @@ def create_connector( # We build separately to enforce strict separation return connector_cls(config, role) + @classmethod + def get_connector_class_by_name( + cls, connector_name: str + ) -> type[KVConnectorBaseType]: + """Get a registered connector class by name. + + Raises ValueError if the connector is not registered. + + Args: + connector_name: Name of the registered connector. + + Returns: + The connector class. 
+ """ + if connector_name not in cls._registry: + raise ValueError(f"Connector '{connector_name}' is not registered.") + return cls._registry[connector_name]() + @classmethod def get_connector_class( cls, kv_transfer_config: "KVTransferConfig" ) -> type[KVConnectorBaseType]: """Get the connector class by name.""" connector_name = kv_transfer_config.kv_connector + if connector_name is None: + raise ValueError("Connector name is not set in KVTransferConfig") if connector_name in cls._registry: connector_cls = cls._registry[connector_name]() else: @@ -76,7 +111,13 @@ def get_connector_class( if connector_module_path is None: raise ValueError(f"Unsupported connector type: {connector_name}") connector_module = importlib.import_module(connector_module_path) - connector_cls = getattr(connector_module, connector_name) + try: + connector_cls = getattr(connector_module, connector_name) + except AttributeError as e: + raise AttributeError( + f"Class {connector_name} not found in {connector_module_path}" + ) from e + connector_cls = cast(type[KVConnectorBaseType], connector_cls) return connector_cls @@ -119,3 +160,9 @@ def get_connector_class( "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector", "OffloadingConnector", ) + +KVConnectorFactory.register_connector( + "DecodeBenchConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector", + "DecodeBenchConnector", +) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 056ece60e84d..22af489a89b9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -4,10 +4,9 @@ KV cache helper for store. """ -from collections import defaultdict from collections.abc import Sequence from concurrent.futures import CancelledError, Future -from typing import Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Literal, cast import torch @@ -18,6 +17,9 @@ from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase + logger = init_logger(__name__) @@ -124,11 +126,16 @@ class KVOutputAggregator: """Utility class to aggregate the output of all workers into a single output corresponding to Rank 0 for scheduler.""" - def __init__(self, world_size: int): + def __init__(self, expected_finished_count: int): # Complete transfer tracker. 
Used to track finished requests # [req_id -> n_remaining_workers] - self._recv_remaining_count = defaultdict[str, int](lambda: world_size) - self._send_remaining_count = defaultdict[str, int](lambda: world_size) + self._recv_remaining_count = dict[str, int]() + self._send_remaining_count = dict[str, int]() + self._expected_finished_count = expected_finished_count + + @classmethod + def from_connector(cls, connector: "KVConnectorBase", world_size: int): + return cls(connector.get_finished_count() or world_size) def aggregate( self, outputs: list[ModelRunnerOutput], output_rank: int = 0 @@ -136,12 +143,15 @@ def aggregate( # Aggregate kv_connector_output from all workers def update_finished_set( - req_ids: Optional[set[str]], + req_ids: set[str] | None, remaining_count_dict: dict[str, int], finished_set: set[str], ) -> None: for req_id in req_ids or (): - remaining_count_dict[req_id] -= 1 + remaining_count = remaining_count_dict.get( + req_id, self._expected_finished_count + ) + remaining_count_dict[req_id] = remaining_count - 1 if remaining_count_dict[req_id] == 0: finished_set.add(req_id) del remaining_count_dict[req_id] @@ -151,21 +161,34 @@ def update_finished_set( aggregated_kv_connector_stats = None invalid_block_ids = set[int]() for model_runner_output in outputs: - output = model_runner_output.kv_connector_output - if not output: + kv_output = model_runner_output.kv_connector_output + if not kv_output: continue + # Allow the worker to dynamically update the expected number of + # finished sending/recving for new requests. + if ( + kv_output.expected_finished_count > 0 + and kv_output.expected_finished_count != self._expected_finished_count + ): + logger.debug( + "Expected finished requests updated from %d to %d", + self._expected_finished_count, + kv_output.expected_finished_count, + ) + self._expected_finished_count = kv_output.expected_finished_count + update_finished_set( - output.finished_sending, self._send_remaining_count, finished_sending + kv_output.finished_sending, self._send_remaining_count, finished_sending ) update_finished_set( - output.finished_recving, self._recv_remaining_count, finished_recving + kv_output.finished_recving, self._recv_remaining_count, finished_recving ) # Aggregate kv_connector_stats from all workers. if aggregated_kv_connector_stats is None: # Use the first worker's kv_connector_stats as accumulator. 
- aggregated_kv_connector_stats = output.kv_connector_stats - elif kv_connector_stats := output.kv_connector_stats: + aggregated_kv_connector_stats = kv_output.kv_connector_stats + elif kv_connector_stats := kv_output.kv_connector_stats: if aggregated_kv_connector_stats is None: aggregated_kv_connector_stats = kv_connector_stats else: @@ -176,7 +199,7 @@ def update_finished_set( aggregated_kv_connector_stats.aggregate(kv_connector_stats) ) - invalid_block_ids |= output.invalid_block_ids + invalid_block_ids |= kv_output.invalid_block_ids # select output of the worker specified by output_rank output = outputs[output_rank] @@ -186,6 +209,7 @@ def update_finished_set( finished_recving=finished_recving or None, kv_connector_stats=aggregated_kv_connector_stats or None, invalid_block_ids=invalid_block_ids, + expected_finished_count=self._expected_finished_count, ) return output @@ -197,7 +221,7 @@ def async_aggregate( to the respective list of outputs.""" result_future: Future[ModelRunnerOutput] = Future() - outputs: list[Optional[ModelRunnerOutput]] = [None] * len(output_futures) + outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures) def make_callback(idx): def callback(fut): @@ -230,8 +254,8 @@ def callback(fut): def _make_src_and_dst_indices( src_block_ids: list[int], dst_block_ids: list[int], - src_device: Union[torch.device, str], - dst_device: Union[torch.device, str], + src_device: torch.device | str, + dst_device: torch.device | str, ) -> tuple[torch.Tensor, torch.Tensor]: src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64) dst_indices = torch.tensor(dst_block_ids, device=dst_device, dtype=torch.int64) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py index 034c7afe97a4..0e16bc5cc685 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -3,6 +3,17 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorRole, + SupportsHMA, + supports_hma, +) +from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import ( # noqa E:501 + DecodeBenchConnector, ) -__all__ = ["KVConnectorRole", "KVConnectorBase_V1"] +__all__ = [ + "KVConnectorRole", + "KVConnectorBase_V1", + "supports_hma", + "SupportsHMA", + "DecodeBenchConnector", +] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 70225e95aed2..2562eb9ce70e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -14,11 +14,12 @@ temporary buffer alloc by the CacheManager. update_connector_output() - update KVConnector state after output is received from worker-side connectors. - request_finished() - called when a request is finished, with - the computed kv cache blocks for the request. - Returns whether KV cache should be freed now or will be - freed asynchronously and optionally returns KV transfer - params. + request_finished() - called once when a request is finished, + with the computed kv cache blocks for the request. + Returns whether KV cache should be freed now or if the + connector now assumes responsibility for freeing the + the blocks asynchronously. Also optionally returns KV + transfer params. take_events() - returns new KV events that were collected by the connector since the last call. 
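The aggregation logic above is easiest to see in isolation: each worker reports the request ids it has finished sending or receiving, a per-request counter starts at the expected worker count, and a request only becomes finished for the scheduler once every expected worker has reported it. A simplified standalone sketch follows; the real KVOutputAggregator additionally merges stats, invalid block ids, and dynamic expected counts.

def aggregate_finished(per_worker_finished: list[set[str]], expected: int) -> set[str]:
    remaining: dict[str, int] = {}
    done: set[str] = set()
    for worker_ids in per_worker_finished:
        for req_id in worker_ids:
            # Start the countdown lazily at the expected worker count.
            remaining[req_id] = remaining.get(req_id, expected) - 1
            if remaining[req_id] == 0:
                done.add(req_id)
                del remaining[req_id]
    return done

# With two expected workers, "req-a" is only finished once both report it:
# aggregate_finished([{"req-a"}, set()], 2)     -> set()
# aggregate_finished([{"req-a"}, {"req-a"}], 2) -> {"req-a"}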
@@ -36,8 +37,8 @@ import enum from abc import ABC, abstractmethod -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any, Literal, Optional import torch @@ -69,6 +70,45 @@ logger = init_logger(__name__) +class SupportsHMA(ABC): + """ + The class that indicates the corresponding connector supports hybrid memory + allocator (HMA). + This is required to use the connector together with hybrid memory allocator. + """ + + @abstractmethod + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called exactly once when a request has finished for all kv cache groups, + before its blocks are freed for each group. + + NOTE(Kuntai): This function is only supported by connectors that support HMA. + + The connector may assume responsibility for freeing the blocks + asynchronously by returning True. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + raise NotImplementedError + + +def supports_hma(connector: Any) -> bool: + if isinstance(connector, type): + return issubclass(connector, SupportsHMA) + else: + return isinstance(connector, SupportsHMA) + + class KVConnectorRole(enum.Enum): # Connector running in the scheduler process SCHEDULER = 0 @@ -92,8 +132,12 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): "Initializing KVConnectorBase_V1. This API is experimental and " "subject to change in the future as we iterate the design." ) - self._connector_metadata: Optional[KVConnectorMetadata] = None + self._connector_metadata: KVConnectorMetadata | None = None self._vllm_config = vllm_config + if vllm_config.kv_transfer_config is not None: + self._kv_transfer_config = vllm_config.kv_transfer_config + else: + raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1") self._role = role @property @@ -221,7 +265,7 @@ def wait_for_save(self): def get_finished( self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + ) -> tuple[set[str] | None, set[str] | None]: """ Notifies worker-side connector ids of requests that have finished generating tokens on the worker. @@ -280,7 +324,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> tuple[Optional[int], bool]: + ) -> tuple[int | None, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -360,9 +404,13 @@ def request_finished( self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ - Called when a request has finished, before its blocks are freed. + Called exactly once when a request has finished, before its blocks are + freed. + + The connector may assume responsibility for freeing the blocks + asynchronously by returning True.
Returns: True if the request is being saved/sent asynchronously and blocks @@ -383,7 +431,7 @@ def take_events(self) -> Iterable["KVCacheEvent"]: return () @classmethod - def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str]: + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: """ Get the required KV cache layout for this connector. Args: @@ -401,10 +449,11 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str] ) return None - def get_finished_count(self) -> Optional[int]: + def get_finished_count(self) -> int | None: """ Get the count of requests expected to complete send/receive operations - via this connector. + via this connector. This method is used to initialize the + KVOutputAggregator, overwriting the default world_size. Returns: int: expected sending or receiving completion count. @@ -414,7 +463,7 @@ def get_finished_count(self) -> Optional[int]: @classmethod def build_kv_connector_stats( - cls, data: Optional[dict[str, Any]] = None + cls, data: dict[str, Any] | None = None ) -> Optional["KVConnectorStats"]: """ KVConnectorStats resolution method. This method allows dynamically diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py new file mode 100644 index 000000000000..17c00b9c3d0e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +DecodeBenchConnector: A KV Connector for decode instance performance testing. + +This connector emulates a prefill-decode disaggregated setting by filling +the KV cache with dummy values, allowing measurement of decoder performance +under larger input sequence lengths (ISL) in resource-limited environments. + +Usage: + To use this connector for benchmarking, configure it in the kv_transfer_config: + + Example: + vllm serve <model> --kv-transfer-config '{ + "kv_connector": "DecodeBenchConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "fill_mean": 0.015, + "fill_std": 0.0 + } + }' + + Then run your benchmark with desired input/output lengths: + vllm bench serve --base-url http://127.0.0.1:8000 --model <model> \\ + --dataset-name random --random-input-len 40000 \\ + --random-output-len 100 --max-concurrency 10 + + Configuration options (via kv_connector_extra_config): + - fill_mean (float): Mean value for random normal fill (default: 0.015) + - fill_std (float): Standard deviation for random fill (default: 0.0) + Set to 0 for constant values, >0 for random sampling +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import torch + +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from vllm.logger import init_logger +from vllm.utils import cdiv + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class DecodeBenchConnectorMetadata(KVConnectorMetadata): + """Metadata for DecodeBenchConnector. 
+ + Contains information about which requests need their KV cache filled + with dummy values for benchmarking purposes. + """ + + # request_id -> (block_ids_per_group, num_tokens_to_fill) + # block_ids_per_group is a tuple of lists, one per KV cache group + # For standard attention: single group, e.g., ([1, 2, 3],) + # For MLA: multiple groups, e.g., ([1, 2], [1, 2]) + reqs_to_fill: dict[str, tuple[tuple[list[int], ...], int]] + + +class DecodeBenchConnector(KVConnectorBase_V1): + """ + A KV Connector for decode instance performance testing. + + This connector fills the KV cache with dummy (non-zero) values to + emulate a prefill-decode disaggregated setting, enabling performance + testing of the decoder with larger input sequence lengths. + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + + self.connector_scheduler: DecodeBenchConnectorScheduler | None = None + self.connector_worker: DecodeBenchConnectorWorker | None = None + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = DecodeBenchConnectorScheduler(vllm_config) + elif role == KVConnectorRole.WORKER: + self.connector_worker = DecodeBenchConnectorWorker(vllm_config) + + # ============================== + # Worker-side methods + # ============================== + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, DecodeBenchConnectorMetadata) + self.connector_worker.start_fill_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + # All operations are synchronous, so nothing to wait for + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + # This connector doesn't save KV cache (benchmarking only) + pass + + def wait_for_save(self): + # This connector doesn't save KV cache (benchmarking only) + pass + + # ============================== + # Scheduler-side methods + # ============================== + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens + ) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens + ) + + def build_connector_meta( + self, scheduler_output: "SchedulerOutput" + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + self.connector_scheduler.request_finished(request) + return False, None + + +class DecodeBenchConnectorScheduler: + """Scheduler-side implementation for DecodeBenchConnector.""" + + def __init__(self, vllm_config: "VllmConfig"): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # Track 
which requests have already been filled + self._filled_requests: set[str] = set() + + # Track pending fills for the current scheduler step + # request_id -> (block_ids_per_group, num_tokens_to_fill) + # Note: _pending_fills doesn't need explicit cleanup - it's cleared + # after build_connector_meta() is called in the same scheduler step + self._pending_fills: dict[str, tuple[tuple[list[int], ...], int]] = {} + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + For new requests, return the number of tokens that should be filled + with dummy KV cache values. + + Returns: + (num_tokens_to_fill, is_async) + - num_tokens_to_fill: number of uncomputed tokens minus 1 + (we fill everything except the last token for decode) + - is_async: False (synchronous filling) + """ + req_id = request.request_id + + # Only fill once per request on first scheduling + if req_id in self._filled_requests: + return 0, False + + # Calculate how many tokens we need to fill + # Fill all uncomputed tokens except the last one (which will be decoded) + # This simulates having processed a long prefill + num_uncomputed_tokens = request.num_tokens - num_computed_tokens + num_tokens_to_fill = max(0, num_uncomputed_tokens - 1) + + if num_tokens_to_fill == 0: + return 0, False + + # Return False for synchronous operation - the fill is fast enough + # that async overhead isn't worth it + return num_tokens_to_fill, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + """ + Called after blocks are allocated. Store the block IDs so we can + fill them with dummy values. + + Supports both standard attention (single KV cache group) and MLA + (multiple KV cache groups). + """ + req_id = request.request_id + + if num_external_tokens == 0: + return + + # Get the block IDs that were allocated + # block_groups is a tuple of lists, one per KV cache group + # For standard attention: 1 group + # For MLA: multiple groups (one per attention type) + block_groups = blocks.get_block_ids() + + # Calculate how many blocks we need to fill + # num_external_tokens are the tokens we said we'd provide + num_blocks_to_fill = cdiv(num_external_tokens, self.block_size) + + # Extract the first num_blocks_to_fill blocks from each group + # All groups should have the same block IDs for the same request + block_ids_per_group = tuple( + group_blocks[:num_blocks_to_fill] for group_blocks in block_groups + ) + + # Store the blocks to fill for all group. _pending_fills doesn't need cleanup + # as it's cleared after build_connector_meta + self._pending_fills[req_id] = ( + block_ids_per_group, + num_external_tokens, + ) + self._filled_requests.add(req_id) + + logger.debug( + "DecodeBenchConnector: Allocated %d blocks across %d KV cache groups " + "for request %s", + num_blocks_to_fill, + len(block_groups), + req_id, + ) + + def build_connector_meta( + self, scheduler_output: "SchedulerOutput" + ) -> KVConnectorMetadata: + """ + Build metadata containing information about which blocks to fill + with dummy KV values. + """ + meta = DecodeBenchConnectorMetadata(reqs_to_fill=self._pending_fills.copy()) + + # Clear pending fills after building metadata + self._pending_fills.clear() + + return meta + + def request_finished(self, request: "Request"): + """ + Called when a request has finished. Clean up any state. 
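A small worked example of the scheduler-side sizing above, using illustrative numbers rather than any particular benchmark configuration: the connector offers to "provide" every uncomputed prompt token except the last, and the number of blocks to fill is the ceiling division of that count by the block size.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vllm.utils.cdiv.
    return -(a // -b)

num_prompt_tokens = 40_000     # e.g. --random-input-len 40000
num_computed_tokens = 0        # nothing cached locally yet
block_size = 16                # illustrative KV cache block size

num_tokens_to_fill = max(0, (num_prompt_tokens - num_computed_tokens) - 1)  # 39_999
num_blocks_to_fill = cdiv(num_tokens_to_fill, block_size)                   # 2_500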
+ """ + self._filled_requests.discard(request.request_id) + + +class DecodeBenchConnectorWorker: + """Worker-side implementation for DecodeBenchConnector.""" + + def __init__(self, vllm_config: "VllmConfig"): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # Get fill parameters from extra config + kv_transfer_config = vllm_config.kv_transfer_config + assert kv_transfer_config is not None + self.fill_mean = kv_transfer_config.get_from_extra_config("fill_mean", 0.015) + self.fill_std = kv_transfer_config.get_from_extra_config("fill_std", 0.0) + + # Will be populated via register_kv_caches + self.kv_caches: dict[str, torch.Tensor] | None = None + + # Mapping from KV cache group index to list of layer names in that group + self.group_to_layers: dict[int, list[str]] | None = None + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Store references to the KV cache tensors and build group mapping.""" + self.kv_caches = kv_caches + + # For simplicity, assume all layers belong to group 0 (standard attention) + # For MLA models with multiple groups, the metadata will handle the mapping + # We just need to fill the blocks specified in the metadata + self.group_to_layers = {0: list(kv_caches.keys())} + + logger.debug( + "DecodeBenchConnector: Registered %d KV cache layers", + len(kv_caches), + ) + + def start_fill_kv(self, metadata: DecodeBenchConnectorMetadata): + """ + Fill the allocated KV cache blocks with dummy (non-zero) values. + + This simulates having a populated KV cache from a prefill phase, + allowing decode performance testing with larger context sizes. + + Supports both standard attention (single group) and MLA (multiple groups). + """ + if not metadata.reqs_to_fill: + return + + assert self.kv_caches is not None, "KV caches must be registered before filling" + assert self.group_to_layers is not None, "Group mapping must be initialized" + + for req_id, (block_ids_per_group, num_tokens) in metadata.reqs_to_fill.items(): + # Fill blocks for each KV cache group + for group_idx, block_ids in enumerate(block_ids_per_group): + self._fill_blocks(group_idx, block_ids, num_tokens) + + logger.debug( + "DecodeBenchConnector: Filled %d blocks (%d tokens) across %d groups " + "for request %s", + len(block_ids_per_group[0]) if block_ids_per_group else 0, + num_tokens, + len(block_ids_per_group), + req_id, + ) + + def _fill_blocks(self, group_idx: int, block_ids: list[int], num_tokens: int): + """ + Fill specified blocks with dummy non-zero values for a specific KV cache group. 
+ + Args: + group_idx: The KV cache group index to fill + block_ids: List of block IDs to fill in this group + num_tokens: Total number of tokens to fill across these blocks + """ + if not block_ids: + return + + assert self.kv_caches is not None + assert self.group_to_layers is not None + + # Get the layers that belong to this group + layer_names = self.group_to_layers.get(group_idx, []) + + # Fill only the layers in this group + for layer_name in layer_names: + if layer_name not in self.kv_caches: + logger.warning( + "DecodeBenchConnector: Layer %s not found in KV caches", layer_name + ) + continue + + kv_cache = self.kv_caches[layer_name] + + # Convert block_ids to tensor on device + block_ids_tensor = torch.tensor( + block_ids, dtype=torch.long, device=kv_cache.device + ) + + # Filter invalid block IDs + valid_mask = block_ids_tensor < kv_cache.shape[0] + valid_block_ids = block_ids_tensor[valid_mask] + + if len(valid_block_ids) == 0: + continue + + # Create fill values - either constant or random + block_shape = kv_cache.shape[1:] + if self.fill_std > 0: + # Random normal sampling + fill_values = torch.normal( + mean=self.fill_mean, + std=self.fill_std, + size=(len(valid_block_ids),) + block_shape, + dtype=kv_cache.dtype, + device=kv_cache.device, + ) + else: + # Constant fill value + fill_values = torch.full( + (len(valid_block_ids),) + block_shape, + self.fill_mean, + dtype=kv_cache.dtype, + device=kv_cache.device, + ) + + # Batch fill operation + kv_cache[valid_block_ids] = fill_values + + logger.debug( + "DecodeBenchConnector: Filled %d blocks in group %d with %s values " + "(mean=%.3f, std=%.3f)", + len(block_ids), + group_idx, + "random" if self.fill_std > 0 else "constant", + self.fill_mean, + self.fill_std, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index b50cc3ab30fa..a5240adab438 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import torch -from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl +from lmcache.integration.vllm.vllm_v1_adapter import ( + LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl, +) from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -11,6 +13,9 @@ KVConnectorMetadata, KVConnectorRole, ) +from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import ( + vllm_v1_adapter as _adapter, +) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput @@ -26,7 +31,18 @@ class LMCacheConnectorV1(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) - self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + assert vllm_config.kv_transfer_config is not None + use_native = vllm_config.kv_transfer_config.get_from_extra_config( + "use_native", False + ) + if use_native: + logger.info("Initializing native LMCache connector") + cls = _adapter.LMCacheConnectorV1Impl + else: + logger.info("Initializing latest dev LMCache connector") + cls = LMCacheConnectorLatestImpl + + self._lmcache_engine = cls(vllm_config, role, self) # 
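Since the adapter choice above is driven by an extra-config flag, here is a hedged sketch of the configuration shape that would select the in-tree ("native") LMCache adapter. The "use_native" key comes from the code above; the other field values are illustrative.

# Equivalent JSON would be passed via --kv-transfer-config on the CLI.
kv_transfer_config = {
    "kv_connector": "LMCacheConnectorV1",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {"use_native": True},
}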
============================== # Worker-side methods @@ -96,7 +112,7 @@ def wait_for_save(self): def get_finished( self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + ) -> tuple[set[str] | None, set[str] | None]: """ Notifies worker-side connector ids of requests that have finished generating tokens. @@ -117,7 +133,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> tuple[Optional[int], bool]: + ) -> tuple[int | None, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -161,7 +177,7 @@ def request_finished( self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ Called when a request has finished, before its blocks are freed. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py new file mode 100644 index 000000000000..e0282c155248 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Standard +import os +import threading +from typing import TYPE_CHECKING, Union + +import torch +from lmcache.config import LMCacheEngineConfig as Config +from lmcache.logging import init_logger +from lmcache.v1.config import LMCacheEngineConfig as V1Config + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.core.sched.output import NewRequestData + from vllm.v1.request import Request + +logger = init_logger(__name__) +ENGINE_NAME = "vllm-instance" + +# Thread-safe singleton storage +_config_instance: Config | V1Config | None = None +_config_lock = threading.Lock() + + +def is_false(value: str) -> bool: + """Check if the given string value is equivalent to 'false'.""" + return value.lower() in ("false", "0", "no", "n", "off") + + +def lmcache_get_or_create_config() -> Config | V1Config: + """Get the LMCache configuration from the environment variable + `LMCACHE_CONFIG_FILE`. If the environment variable is not set, this + function will return the default configuration. + + This function is thread-safe and implements singleton pattern, + ensuring the configuration is loaded only once. + """ + global _config_instance + + # Double-checked locking for thread-safe singleton + if _config_instance is None: + with _config_lock: + if _config_instance is None: # Check again within lock + if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")): + logger.warning( + "Detected LMCACHE_USE_EXPERIMENTAL is set to False. " + "Using legacy configuration is deprecated and will " + "be remove soon! Please set LMCACHE_USE_EXPERIMENTAL " + "to True." 
+ ) + LMCacheEngineConfig = Config # type: ignore[assignment] + else: + LMCacheEngineConfig = V1Config # type: ignore[assignment] + + if "LMCACHE_CONFIG_FILE" not in os.environ: + logger.warning( + "No LMCache configuration file is set. Trying to read" + " configurations from the environment variables." + ) + logger.warning( + "You can set the configuration file through " + "the environment variable: LMCACHE_CONFIG_FILE" + ) + _config_instance = LMCacheEngineConfig.from_env() + else: + config_file = os.environ["LMCACHE_CONFIG_FILE"] + logger.info("Loading LMCache config file %s", config_file) + _config_instance = LMCacheEngineConfig.from_file(config_file) + # Update config from environment variables + _config_instance.update_config_from_env() + return _config_instance + + +def hex_hash_to_int16(s: str) -> int: + """ + Convert a hex hash string to a 16-bit integer. + """ + return int(s, 16) & 0xFFFF + + +def apply_mm_hashes_to_token_ids( + token_ids: torch.Tensor, + mm_hashes: list[str], + mm_positions: list["PlaceholderRange"], +) -> torch.Tensor: + """ + Overwrite token_ids in-place for multimodal placeholders using + efficient slice assignments. + """ + n = token_ids.size(0) + for hash_str, placeholder in zip(mm_hashes, mm_positions): + start, length = placeholder.offset, placeholder.length + if start >= n: + continue + end = min(start + length, n) + token_ids[start:end] = hex_hash_to_int16(hash_str) + return token_ids + + +def mla_enabled(model_config: "ModelConfig") -> bool: + return ( + hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla + ) + + +def create_lmcache_metadata( + vllm_config=None, model_config=None, parallel_config=None, cache_config=None +): + """ + Create LMCacheEngineMetadata from vLLM configuration. + + This function extracts common metadata creation logic that was duplicated + across multiple files. + + Args: + vllm_config (VllmConfig): vLLM configuration object containing model, + parallel, and cache configs (alternative to + individual config parameters) + model_config (ModelConfig): Model configuration (alternative to + vllm_config) + parallel_config (ParallelConfig): Parallel configuration (alternative + to vllm_config) + cache_config (CacheConfig): Cache configuration (alternative to + vllm_config) + """ + # Third Party + # First Party + from lmcache.config import LMCacheEngineMetadata + + from vllm.utils import get_kv_cache_torch_dtype + + config = lmcache_get_or_create_config() + # Support both vllm_config object and individual config parameters + if vllm_config is not None: + model_cfg = vllm_config.model_config + parallel_cfg = vllm_config.parallel_config + cache_cfg = vllm_config.cache_config + else: + if model_config is None or parallel_config is None or cache_config is None: + raise ValueError( + "Either vllm_config must be provided, or all of " + "model_config, parallel_config, and cache_config must be provided." 
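A quick worked example of the multimodal hashing helpers above, with made-up values: hex_hash_to_int16 keeps only the low 16 bits of the hex hash, and apply_mm_hashes_to_token_ids writes that value over the placeholder range of the token id tensor in place.

import torch

# int("deadbeef", 16) & 0xFFFF == 0xBEEF == 48879
token_ids = torch.arange(10)
offset, length = 2, 4              # illustrative PlaceholderRange
token_ids[offset:offset + length] = 48879
# token_ids -> tensor([0, 1, 48879, 48879, 48879, 48879, 6, 7, 8, 9])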
+ ) + model_cfg = model_config + parallel_cfg = parallel_config + cache_cfg = cache_config + + # Get KV cache dtype + kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype) + + # Check if MLA is enabled + use_mla = mla_enabled(model_cfg) + + # Construct KV shape (for memory pool) + num_layer = model_cfg.get_num_layers(parallel_cfg) + chunk_size = config.chunk_size + num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg) + head_size = model_cfg.get_head_size() + kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size) + + # Create metadata + metadata = LMCacheEngineMetadata( + model_cfg.model, + parallel_cfg.world_size, + parallel_cfg.rank, + "vllm", + kv_dtype, + kv_shape, + use_mla, + ) + + return metadata, config + + +def extract_mm_features( + request: Union["Request", "NewRequestData"], modify: bool = False +) -> tuple[list[str], list["PlaceholderRange"]]: + """ + Normalize multimodal information from a Request into parallel lists. + + This helper reads either: + 1) `request.mm_features` (objects each exposing `.identifier` and + `.mm_position`), or + 2) legacy fields `request.mm_hashes` and `request.mm_positions`. + + It returns two equally sized lists: the multimodal hash identifiers and + their corresponding positions. If the request contains no multimodal info, + it returns `([], [])`. + + Args: + request (Request): The source object. + modify (bool): + Controls copy semantics for the legacy-path return values. + - If True and legacy fields are used, shallow-copies are returned so + the caller can mutate the lists without affecting `request`. + - If False, the original legacy sequences are returned as-is + (zero-copy); treat them as read-only. + + Returns: + tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`). + May be `([], [])` when no multimodal data is present. 
+ """ + if getattr(request, "mm_features", None): + mm_hashes, mm_positions = zip( + *((f.identifier, f.mm_position) for f in request.mm_features) + ) + return (list(mm_hashes), list(mm_positions)) + elif getattr(request, "mm_hashes", None): + if modify: + return ( + request.mm_hashes.copy(), # type: ignore + request.mm_positions.copy(), # type: ignore + ) + else: + return (request.mm_hashes, request.mm_positions) # type: ignore + else: + return ([], []) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py new file mode 100644 index 000000000000..1f42b598bc9c --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -0,0 +1,1396 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Standard +import os +import uuid +from collections.abc import Generator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional + +import torch +from lmcache import utils +from lmcache.config import LMCacheEngineMetadata +from lmcache.logging import init_logger +from lmcache.observability import LMCStatsMonitor +from lmcache.utils import _lmcache_nvtx_annotate +from lmcache.v1.cache_engine import LMCacheEngine, LMCacheEngineBuilder +from lmcache.v1.compute.blend import LMCBlenderBuilder +from lmcache.v1.config import LMCacheEngineConfig, _validate_and_set_config_value +from lmcache.v1.gpu_connector import ( + VLLMBufferLayerwiseGPUConnector, + VLLMPagedMemGPUConnectorV2, + VLLMPagedMemLayerwiseGPUConnector, +) +from lmcache.v1.internal_api_server.api_server import InternalAPIServer +from lmcache.v1.lookup_client import LookupClientFactory +from lmcache.v1.lookup_client.lmcache_async_lookup_client import ( + LMCacheAsyncLookupServer, +) +from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer +from lmcache.v1.plugin.plugin_launcher import PluginLauncher + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( + ENGINE_NAME, + apply_mm_hashes_to_token_ids, + extract_mm_features, + lmcache_get_or_create_config, + mla_enabled, +) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group +from vllm.sampling_params import SamplingParams +from vllm.utils import cdiv, get_kv_cache_torch_dtype +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.version import __version__ as VLLM_VERSION + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.multimodal.inputs import PlaceholderRange + from vllm.v1.core.kv_cache_manager import KVCacheManager + from vllm.v1.core.sched.output import NewRequestData + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class LoadSpec: + # Number of tokens cached in vLLM + vllm_cached_tokens: int + # Number of tokens that are cached in LMCache + lmcache_cached_tokens: int + # Whether the scheduler allow us to load the tokens + can_load: bool + + +@dataclass +class SaveSpec: + # Skip already saved tokens + skip_leading_tokens: int + # Whether the scheduler allow us to save the tokens + can_save: bool + + +@dataclass +class DisaggSpec: + req_id: str + 
receiver_id: str + receiver_host: str + receiver_init_port: int + receiver_alloc_port: int + is_last_prefill: bool = False + num_transferred_tokens: int = 0 + + +tmp_disagg_tracker: dict[str, DisaggSpec] = {} + + +def extract_request_configs(sampling_params: SamplingParams) -> dict | None: + request_configs = None + if ( + sampling_params.extra_args is not None + and "kv_transfer_params" in sampling_params.extra_args + ): + kv_transfer_params = sampling_params.extra_args.get("kv_transfer_params") + if kv_transfer_params is None: + return None + assert isinstance(kv_transfer_params, dict) + for k, v in kv_transfer_params.items(): + if k.startswith("lmcache."): + if request_configs is None: + request_configs = {} + request_configs[k] = v + return request_configs + + +@dataclass +class RequestTracker: + # Request id + req_id: str + + # Total prompt token length + prompt_len: int + + # The token ids that has been scheduled so far + token_ids: list[int] + + # The block ids that has been allocated so far + # NOTE: allocated blocks could be more than the number of tokens + allocated_block_ids: list[int] + + # The number of tokens that has been saved + num_saved_tokens: int = 0 + + # Disagg spec for the request + disagg_spec: DisaggSpec | None = None + + # Multimodal hashes and positions + mm_hashes: list[str] | None = None + mm_positions: list["PlaceholderRange"] | None = None + + # The configs of the request, includes tags and other configs + request_configs: dict | None = None + + # Whether the request is in decode phase + is_decode_phase = False + + # Whether the request cache should be saved + skip_save: bool = False + + @_lmcache_nvtx_annotate + @staticmethod + def from_new_request( + lmcache_config: LMCacheEngineConfig, + new_request: "NewRequestData", + num_tokens_to_compute: int, + lmcache_cached_tokens: int, + skip_save: bool, + ) -> "RequestTracker": + """Create the request tracker from a new request. + + Args: + lmcache_config (LMCacheEngineConfig): the LMCache engine config. + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. + lmcache_cached_tokens (int): the number of tokens that are + cached in LMCache. + skip_save (bool): whether the request cache should be saved + """ + # vLLM 0.9.0 update: request.block_ids changed from list[int] to + # list[list[int]] + # Need to check the type of request.block_ids + + unfolded_block_ids = [] + + if not isinstance(new_request.block_ids[0], list): + unfolded_block_ids = new_request.block_ids.copy() + else: + # According to the vLLM code + # (https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/ + # sched/scheduler.py#L943), + # only one KVCacheGroup is supported in connector for now. 
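The block-id normalization above exists because newer vLLM passes one list of block ids per KV cache group while older versions passed a flat list, and only the first group is used by the connector. A minimal sketch of that branch:

def unfold_block_ids(block_ids):
    # list[list[int]] (per KV cache group) -> take the first group;
    # a flat list[int] is used as-is.
    if block_ids and isinstance(block_ids[0], list):
        return block_ids[0].copy()
    return list(block_ids)

# unfold_block_ids([[3, 4, 5]]) -> [3, 4, 5]
# unfold_block_ids([3, 4, 5])   -> [3, 4, 5]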
+ unfolded_block_ids = new_request.block_ids[0].copy() + + # NOTE: Initialized in `update_state_after_alloc` + disagg_spec = tmp_disagg_tracker.pop(new_request.req_id, None) + + if new_request.sampling_params: + request_configs = extract_request_configs(new_request.sampling_params) + else: + request_configs = None + + mm_hashes, mm_positions = extract_mm_features(new_request, modify=True) + + assert new_request.prompt_token_ids is not None + return RequestTracker( + req_id=new_request.req_id, + prompt_len=len(new_request.prompt_token_ids), + token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].copy(), + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=lmcache_cached_tokens, + disagg_spec=disagg_spec, + mm_hashes=mm_hashes, + mm_positions=mm_positions, + skip_save=skip_save, + request_configs=request_configs, + ) + + def update( + self, + new_token_ids: list[int], + new_block_ids: tuple[list[int], ...] | None | list[int], + ) -> None: + """Update the request tracker when a running request is + scheduled again + """ + + self.token_ids.extend(new_token_ids) + + if new_block_ids is None: + # https://github.com/vllm-project/vllm/commit/ + # b029de9902aa3ac58806c8c17776c7074175b6db + new_block_ids = [] + elif len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}") + self.allocated_block_ids.extend(new_block_ids) + + # When a request is scheduled again, and the number of new tokens + # is 1 (excluding chunked prefill), the request is in decode phase. + if len(new_token_ids) == 1: + self.is_decode_phase = True + + +@dataclass +class ReqMeta: + # Request id + req_id: str + # Request tokens + token_ids: list[int] # torch.Tensor + # Slot mapping + slot_mapping: torch.Tensor + + # Whether is last prefill or not + is_last_prefill: bool = False + + # Skip save or not + save_spec: SaveSpec | None = None + # load_spec + load_spec: LoadSpec | None = None + # disagg spec + disagg_spec: DisaggSpec | None = None + # the configs of the request + request_configs: dict | None = None + + @staticmethod + def from_request_tracker( + tracker: RequestTracker, + block_size: int, + lmcache_chunk_size: int = 256, + load_spec: LoadSpec | None = None, + discard_partial_chunks: bool = True, + save_decode_cache: bool = False, + ) -> Optional["ReqMeta"]: + """Create the request metadata from a request tracker. + + Args: + tracker (RequestTracker): the request tracker. + block_size (int): the block size in vLLM. + lmcache_chunk_size (int): the chunk size for LMCache. + load_spec (Optional[LoadSpec]): the load spec for KV cache loading. + discard_partial_chunks (bool): whether to discard partial chunks. + save_decode_cache (bool): whether to save the cache in decode phase. + + Returns: + the request metadata if we need to perform load/save + operations, None otherwise. + """ + input_token_ids = tracker.token_ids + input_token_len = len(input_token_ids) + + is_last_prefill = False + if input_token_len == tracker.prompt_len: + is_last_prefill = True + + # For save operation: do not save if the following condition is met + # 1. has already been saved before (num_saved_tokens > 0) + # 2. number of unsaved tokens is not reached the chunk boundary + # 3. 
if save_decode_cache is False and it is in decode phase + + skip_leading_tokens = tracker.num_saved_tokens + chunk_boundary = ( + cdiv(tracker.num_saved_tokens + 1, lmcache_chunk_size) * lmcache_chunk_size + ) + + # NOTE(vladnosiv): for disagg, you cannot skip saving, as saving is a + # trqansfer. Check if request_configs has lmcache.skip_save set to True + request_skip = (tracker.request_configs or {}).get("lmcache.skip_save", False) + + skip_save = tracker.disagg_spec is None and ( + tracker.skip_save + or (tracker.num_saved_tokens > 0 and input_token_len < chunk_boundary) + or (tracker.is_decode_phase and not save_decode_cache) + or request_skip + ) + + if skip_save and load_spec is None: + return None + + # Calculate number of tokens to save based on discard_partial_chunks + # setting + + # NOTE(vladnosiv): for the input_token_len chunk prefill, + # we are required to discard partial chunks, + # as new tokens will be added in the next iteration. + num_tokens_to_save = ( + (input_token_len // lmcache_chunk_size * lmcache_chunk_size) + if not is_last_prefill or discard_partial_chunks + else input_token_len + ) + + # If we need to save, update the number of saved tokens + if not skip_save: + tracker.num_saved_tokens = num_tokens_to_save + save_spec = SaveSpec(skip_leading_tokens, not skip_save) + + # Calculate the token ids and slot mappings for load and save + token_ids = input_token_ids[:num_tokens_to_save] + + # If the request has multimodal hashes, apply them to the token ids + if tracker.mm_hashes: + token_ids_tensor = torch.tensor(token_ids) + assert tracker.mm_positions is not None, ( + "tracker got mm_hashes but no mm_positions" + ) + apply_mm_hashes_to_token_ids( + token_ids_tensor, tracker.mm_hashes, tracker.mm_positions + ) + token_ids = token_ids_tensor.tolist() + + num_blocks = len(tracker.allocated_block_ids) + + if len(token_ids) > num_blocks * block_size: + logger.error( + "The number of tokens is more than the number of blocks." + "Something might be wrong in scheduling logic!" 
+            )
+            logger.error(
+                "Num tokens: %d, num blocks: %d, block size: %d",
+                len(token_ids),
+                num_blocks,
+                block_size,
+            )
+
+        block_ids = torch.tensor(tracker.allocated_block_ids, dtype=torch.long)
+        block_offsets = torch.arange(0, block_size, dtype=torch.long)
+        slot_mapping = (
+            block_offsets.reshape((1, block_size))
+            + block_ids.reshape((num_blocks, 1)) * block_size
+        )
+
+        slot_mapping = slot_mapping.flatten()[: len(token_ids)]
+        assert slot_mapping.dtype == torch.long
+
+        # For load operation: check whether the request is scheduled to load
+        if load_spec is not None and load_spec.can_load:
+            logger.debug(
+                "Scheduled to load %d tokens for request %s",
+                load_spec.lmcache_cached_tokens,
+                tracker.req_id,
+            )
+        else:
+            # Do not load if not in `can_load` state
+            load_spec = None
+
+        return ReqMeta(
+            req_id=tracker.req_id,
+            token_ids=token_ids,
+            slot_mapping=slot_mapping,
+            is_last_prefill=is_last_prefill,
+            save_spec=save_spec,
+            load_spec=load_spec,
+            disagg_spec=tracker.disagg_spec,
+            request_configs=tracker.request_configs,
+        )
+
+
+def need_gpu_interm_buffer(lmcache_config: LMCacheEngineConfig):
+    return lmcache_config.enable_pd
+
+
+def _calculate_mtp_layers(vllm_config, model_config):
+    num_mtp_layers = 0
+    if vllm_config is not None and vllm_config.speculative_config is not None:
+        logger.info(
+            "vllm_config.speculative_config: %s", vllm_config.speculative_config
+        )
+        # TODO(baoloongmao): Support other MTP methods
+        if vllm_config.speculative_config.method == "deepseek_mtp":
+            num_mtp_layers = getattr(
+                model_config.hf_config, "num_nextn_predict_layers", 0
+            )
+    return num_mtp_layers
+
+
+def _init_lmcache_engine(
+    lmcache_config: LMCacheEngineConfig,
+    vllm_config: "VllmConfig",
+) -> LMCacheEngine:
+    """Initialize the LMCache engine from the given LMCache and vLLM
+    configurations. This function will check the environment variable
+    `LMCACHE_CONFIG_FILE` to load the configuration file. If that environment
+    variable is not set, this function will return None.
+
+    :param lmcache_config: The LMCache configuration.
+    :type lmcache_config: LMCacheEngineConfig
+    :param vllm_config: The vLLM configuration.
+    :type vllm_config: VllmConfig
+
+    :return: The initialized LMCache engine
+    :rtype: LMCacheEngine
+    """
+    if curr_engine := LMCacheEngineBuilder.get(ENGINE_NAME):
+        return curr_engine
+
+    model_config = vllm_config.model_config
+    parallel_config = vllm_config.parallel_config
+    cache_config = vllm_config.cache_config
+
+    assert isinstance(lmcache_config, LMCacheEngineConfig), (
+        "An LMCache v1 configuration should be passed."
+    )
+
+    kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype)
+
+    use_mla = mla_enabled(model_config)
+    if use_mla and (
+        lmcache_config.remote_serde != "naive"
+        and lmcache_config.remote_serde is not None
+    ):
+        raise ValueError("MLA only works with the naive serde mode.")
+
+    # Construct the KV shape (for the memory pool)
+    num_layer = model_config.get_num_layers(parallel_config)
+    num_mtp_layers = _calculate_mtp_layers(vllm_config, model_config)
+    num_layer += num_mtp_layers
+    chunk_size = lmcache_config.chunk_size
+    num_kv_head = model_config.get_num_kv_heads(parallel_config)
+    head_size = model_config.get_head_size()
+    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
+    logger.info(
+        "use mla: %s, kv shape: %s, num_mtp_layers: %s",
+        use_mla,
+        kv_shape,
+        num_mtp_layers,
+    )
+
+    # Change current device.
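+    # Illustrative example (assumed values): on an 8-GPU node, a worker with
+    # parallel_config.rank == 11 gets local_rank == 11 % 8 == 3 and therefore
+    # binds to "cuda:3" below.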
+    num_gpus = torch.cuda.device_count()
+    local_rank = parallel_config.rank % num_gpus
+    torch.cuda.set_device(local_rank)
+    device = torch.device(f"cuda:{local_rank}")
+    metadata = LMCacheEngineMetadata(
+        model_config.model,
+        parallel_config.world_size,
+        parallel_config.rank,
+        "vllm",
+        kv_dtype,
+        kv_shape,
+        use_mla,
+    )
+
+    use_gpu = need_gpu_interm_buffer(lmcache_config)
+    vllm_gpu_connector: (
+        VLLMBufferLayerwiseGPUConnector
+        | VLLMPagedMemGPUConnectorV2
+        | VLLMPagedMemLayerwiseGPUConnector
+    )
+
+    if use_mla and lmcache_config.use_layerwise:
+        raise ValueError("layerwise MLA connector is not supported yet")
+
+    # When use_mla is True, num_kv_head is 1
+    hidden_dim_size = num_kv_head * head_size
+    if lmcache_config.use_layerwise:
+        if lmcache_config.enable_blending:
+            # Use the layerwise connector for blending
+            vllm_gpu_connector = VLLMBufferLayerwiseGPUConnector(
+                hidden_dim_size,
+                num_layer,
+                use_gpu=use_gpu,
+                chunk_size=chunk_size,
+                dtype=kv_dtype,
+                device=device,
+            )
+        else:
+            vllm_gpu_connector = VLLMPagedMemLayerwiseGPUConnector(
+                hidden_dim_size,
+                num_layer,
+                use_gpu=use_gpu,
+                chunk_size=chunk_size,
+                dtype=kv_dtype,
+                device=device,
+            )
+    else:
+        vllm_gpu_connector = VLLMPagedMemGPUConnectorV2(
+            hidden_dim_size,
+            num_layer,
+            use_gpu=use_gpu,
+            chunk_size=chunk_size,
+            dtype=kv_dtype,
+            device=device,
+            use_mla=use_mla,
+        )
+    tpg = get_tp_group()
+    engine = LMCacheEngineBuilder.get_or_create(
+        ENGINE_NAME,
+        lmcache_config,
+        metadata,
+        vllm_gpu_connector,
+        tpg.broadcast,
+        tpg.broadcast_object,
+    )
+
+    return engine
+
+
+@dataclass
+class LMCacheConnectorMetadata(KVConnectorMetadata):
+    requests: list[ReqMeta] = field(default_factory=list)
+    lookup_requests_in_step: list[str] = field(default_factory=list)
+
+    @_lmcache_nvtx_annotate
+    def add_request(self, req_meta: ReqMeta) -> None:
+        """Add a request to the metadata.
+
+        Args:
+            req_meta (ReqMeta): the request metadata.
+        """
+        self.requests.append(req_meta)
+
+
+class LMCacheConnectorV1Impl:
+    def __init__(
+        self,
+        vllm_config: "VllmConfig",
+        role: KVConnectorRole,
+        parent: KVConnectorBase_V1,
+    ):
+        assert vllm_config.kv_transfer_config is not None
+        self._parent = parent
+        self._vllm_config = vllm_config
+        self.kv_role = vllm_config.kv_transfer_config.kv_role
+        self.worker_count = vllm_config.parallel_config.tensor_parallel_size
+        config = lmcache_get_or_create_config()
+        assert isinstance(config, LMCacheEngineConfig), (
+            "An LMCache v1 configuration should be passed for vLLM v1."
+        )
+        # Copy entries whose keys start with "lmcache." from the vLLM
+        # kv_connector extra config into the LMCache config
+        kv_connector_extra_config = (
+            vllm_config.kv_transfer_config.kv_connector_extra_config
+        )
+        if kv_connector_extra_config:
+            for key, value in kv_connector_extra_config.items():
+                if key.startswith("lmcache."):
+                    config_key = key[8:]  # Remove "lmcache."
prefix + if _validate_and_set_config_value(config, config_key, value): + logger.info( + "Updated config %s from vLLM extra config: %s", + config_key, + value, + ) + + self.config = config + + self.async_loading = config.enable_async_loading + self.layerwise_retrievers: list[Generator[torch.Tensor | None, None, None]] = [] + self._stats_monitor = LMCStatsMonitor.GetOrCreate() + if role == KVConnectorRole.SCHEDULER: + # Create lookup client using factory + self.lookup_client = LookupClientFactory.create_lookup_client( + vllm_config, config + ) + self._unfinished_requests: dict[str, Request] = {} + self._lookup_requests_in_step: list[str] = [] + self.lmcache_engine = None + else: + self.lmcache_engine = _init_lmcache_engine( + config, + vllm_config, + ) + + self.use_layerwise = config.use_layerwise + self.enable_blending = config.enable_blending + + if self.enable_blending: + self.blender = LMCBlenderBuilder.get_or_create( + ENGINE_NAME, + self.lmcache_engine, + self.lmcache_engine.gpu_connector, + config, + ) + + # Create lookup server using factory + assert self.lmcache_engine is not None + self.lookup_server = LookupClientFactory.create_lookup_server( + self.lmcache_engine, vllm_config + ) + + self.offload_server = ZMQOffloadServer( + self.lmcache_engine, + vllm_config, + get_tensor_model_parallel_rank(), + ) + + # In case of MLA, the lookup server is only created on worker 0 + if self.async_loading and self.lookup_server is not None: + assert isinstance(self.lookup_server, LMCacheAsyncLookupServer) + self.lmcache_engine.post_init(async_lookup_server=self.lookup_server) + + self.kv_caches: dict[str, torch.Tensor] = {} + + self._block_size = vllm_config.cache_config.block_size + + # request_id -> (vllm cached tokens, lmcache cached tokens) + self.load_specs: dict[str, LoadSpec] = {} + + self.kv_cache_manager: KVCacheManager | None = None + + # request_id -> full_token_ids + self._request_trackers: dict[str, RequestTracker] = {} + + # Whether to discard partial chunks + self._discard_partial_chunks = ( + vllm_config.kv_transfer_config.get_from_extra_config( + "discard_partial_chunks", False + ) + or not config.save_unfull_chunk + ) + + self._lmcache_chunk_size = config.chunk_size + self._save_decode_cache = config.save_decode_cache + + self.skip_last_n_tokens = vllm_config.kv_transfer_config.get_from_extra_config( + "skip_last_n_tokens", 0 + ) + + self.num_layers = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config + ) + self.current_layer = 0 + + self.force_skip_save = bool(os.environ.get("LMCACHE_FORCE_SKIP_SAVE", False)) + + self._requests_priority: dict[str, int] = {} + + # TODO(baoloongmao): Internal api server & plugin framework support + # dp > 1 + if ( + vllm_config.parallel_config.data_parallel_size_local == 1 + or vllm_config.parallel_config.data_parallel_rank_local == 0 + ): + # Start internal API server if enabled + # The enabled check is in the InternalAPIServer constructor + self.api_server = InternalAPIServer(self) + self.api_server.start() + # Launch plugins + self.plugin_launcher = PluginLauncher( + self.config, + role, + self.worker_count, + -1 + if self.lmcache_engine is None # scheduler side + else self.lmcache_engine.metadata.worker_id, + ) + self.plugin_launcher.launch_plugins() + else: + self.api_server = None # type: ignore[assignment] + self.plugin_launcher = None # type: ignore[assignment] + logger.info( + "LMCache initialized for role %s with version %s, " + "vllm version %s, lmcache cache_engine metadata: %s", + role, + utils.get_version(), + 
VLLM_VERSION, + getattr(self.lmcache_engine, "metadata", None), + ) + + def get_inference_info(self) -> dict: + """Get inference information including vLLM config and related details. + + Returns: + dict: Dictionary containing inference information + """ + # Get vLLM config information + vllm_config = self._vllm_config + + # Use vLLM config's string representation and add specific configs + inference_info = { + "vllm_version": VLLM_VERSION, + "lmcache_version": utils.get_version(), + "vllm_config": str(vllm_config), + "model_config": { + "model": getattr(vllm_config.model_config, "model", None), + "dtype": str(getattr(vllm_config.model_config, "dtype", None)), + "max_model_len": getattr( + vllm_config.model_config, "max_model_len", None + ), + "vocab_size": getattr(vllm_config.model_config, "vocab_size", None), + "num_layers": getattr( + vllm_config.model_config, "get_num_layers", lambda _: None + )(vllm_config.parallel_config), + "num_attention_heads": getattr( + vllm_config.model_config, "get_num_attention_heads", lambda _: None + )(vllm_config.parallel_config), + "num_kv_heads": getattr( + vllm_config.model_config, "get_num_kv_heads", lambda _: None + )(vllm_config.parallel_config), + "head_size": getattr( + vllm_config.model_config, "get_head_size", lambda: None + )(), + }, + "cache_config": { + "block_size": getattr(vllm_config.cache_config, "block_size", None), + "cache_dtype": str( + getattr(vllm_config.cache_config, "cache_dtype", None) + ), + "gpu_memory_utilization": getattr( + vllm_config.cache_config, "gpu_memory_utilization", None + ), + "swap_space": getattr(vllm_config.cache_config, "swap_space", None), + "enable_prefix_caching": getattr( + vllm_config.cache_config, "enable_prefix_caching", None + ), + }, + } + + return inference_info + + def get_inference_version(self) -> str: + """Get vLLM version information. + + Returns: + str: vLLM version string + """ + return VLLM_VERSION + + @_lmcache_nvtx_annotate + def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"): + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + if not hasattr(attn_layer, "kv_cache"): + logger.debug("The layer %s does not have kv_cache, skip it", layer_name) + continue + + if layer_name not in self.kv_caches: + self.kv_caches[layer_name] = attn_layer.kv_cache[ + forward_context.virtual_engine + ] + + #################### + # Worker side APIs + #################### + + @_lmcache_nvtx_annotate + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + + Note: + The number of elements in kv_caches and layer_names should be + the same. 
+ """ + self.current_layer = 0 + + if len(self.kv_caches) == 0: + self._init_kv_caches_from_forward_context(forward_context) + + metadata = self._parent._get_connector_metadata() + assert isinstance(metadata, LMCacheConnectorMetadata) + + assert len(self.kv_caches) > 0 + kvcaches = list(self.kv_caches.values()) + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.debug("In connector.start_load_kv, but the attn_metadata is None") + return + + assert self.lmcache_engine is not None + + self.lmcache_engine.post_init(kvcaches=kvcaches) + + self.layerwise_retrievers = [] + + for idx, request in enumerate(metadata.requests): + if request.load_spec is None: + continue + last_idx = idx + + for idx, request in enumerate(metadata.requests): + if request.load_spec is None: + continue + + tokens = request.token_ids + # TODO: have a pre-allocated buffer to hold the slot_mappings + slot_mapping = request.slot_mapping.cuda() + assert len(tokens) == len(slot_mapping) + + self._stats_monitor.update_interval_vllm_hit_tokens( + request.load_spec.vllm_cached_tokens + ) + token_mask = torch.ones(len(tokens), dtype=torch.bool) + masked_token_count = ( + request.load_spec.vllm_cached_tokens + // self._lmcache_chunk_size + * self._lmcache_chunk_size + ) + token_mask[:masked_token_count] = False + + lmcache_cached_tokens = request.load_spec.lmcache_cached_tokens + if self.use_layerwise: + sync = idx == last_idx + # NOTE(Jiayi): Perform blending before layerwise prefix caching + if self.enable_blending: + # TODO(Jiayi): Need to make prefix caching and blending + # compatible + self.blender.blend( + tokens[:lmcache_cached_tokens], + token_mask[:lmcache_cached_tokens], + kvcaches=kvcaches, + slot_mapping=slot_mapping[:lmcache_cached_tokens], + ) + else: + layerwise_retriever = self.lmcache_engine.retrieve_layer( + tokens[:lmcache_cached_tokens], + token_mask[:lmcache_cached_tokens], + kvcaches=kvcaches, + slot_mapping=slot_mapping[:lmcache_cached_tokens], + sync=sync, + ) + # NOTE: retrieve for two layers at the first layer + next(layerwise_retriever) + next(layerwise_retriever) + self.layerwise_retrievers.append(layerwise_retriever) + else: + ret_token_mask = self.lmcache_engine.retrieve( + tokens[:lmcache_cached_tokens], + token_mask[:lmcache_cached_tokens], + kvcaches=kvcaches, + slot_mapping=slot_mapping[:lmcache_cached_tokens], + request_configs=request.request_configs, + req_id=request.req_id, + ) + + # Check the result + num_retrieved_tokens = ret_token_mask.sum().item() + num_expected_tokens = ( + lmcache_cached_tokens - request.load_spec.vllm_cached_tokens + ) + if num_retrieved_tokens < num_expected_tokens: + logger.error( + "The number of retrieved tokens is less than the " + "expected number of tokens! This should not happen!" + ) + logger.error( + "Num retrieved tokens: %d, num expected tokens: %d", + num_retrieved_tokens, + num_expected_tokens, + ) + + @_lmcache_nvtx_annotate + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. 
+ + Args: + layer_name: the name of that layer + """ + if self.layerwise_retrievers: + logger.debug("Waiting for layer %s to be loaded", self.current_layer) + + # Wait for the layer to be loaded + for layerwise_retriever in self.layerwise_retrievers: + ret_token_mask = next(layerwise_retriever) + + if self.current_layer == self.num_layers - 1: + assert ret_token_mask is not None + num_retrieved_tokens = ret_token_mask.sum().item() + logger.info("Retrieved %s tokens", num_retrieved_tokens) + + return + + @_lmcache_nvtx_annotate + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + """Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + """ + assert self.lmcache_engine is not None + + if not self.use_layerwise: + return + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + if self._parent._connector_metadata is None: + logger.warning( + "In connector.save_kv_layer, but the connector metadata is None" + ) + return + connector_metadata = self._parent._get_connector_metadata() + assert isinstance(connector_metadata, LMCacheConnectorMetadata) + + assert len(self.kv_caches) > 0 + + kvcaches = list(self.kv_caches.values()) + if self.current_layer == 0: + self.layerwise_storers = [] + + is_first = True + + for idx, request in enumerate(connector_metadata.requests): + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + assert isinstance(token_ids, list) + + slot_mapping = request.slot_mapping + assert isinstance(slot_mapping, torch.Tensor) + assert len(slot_mapping) == len(token_ids) + + # TODO: have a pre-allocated buffer to hold the slot_mappings + slot_mapping = slot_mapping.cuda() + + if self.kv_role == "kv_producer": + skip_leading_tokens = 0 + else: + skip_leading_tokens = save_spec.skip_leading_tokens + + if skip_leading_tokens == len(token_ids): + continue # skip this request + # Align to lmcache chunk size + skip_leading_tokens = ( + skip_leading_tokens + // self._lmcache_chunk_size + * self._lmcache_chunk_size + ) + + store_mask = torch.ones(len(token_ids), dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + # TODO (Jiayi): need to make layerwise storing + # compatible with disagg spec + layerwise_storer = self.lmcache_engine.store_layer( + token_ids, + mask=store_mask, + kvcaches=kvcaches, + slot_mapping=slot_mapping, + offset=skip_leading_tokens, + sync=is_first, + ) + self.layerwise_storers.append(layerwise_storer) + if is_first: + is_first = False + + for layerwise_storer in self.layerwise_storers: + next(layerwise_storer) + + self.current_layer += 1 + + @_lmcache_nvtx_annotate + def wait_for_save(self): + """Blocking until the KV cache is saved to the connector buffer.""" + + connector_metadata = self._parent._get_connector_metadata() + assert isinstance(connector_metadata, LMCacheConnectorMetadata) + + self.lmcache_engine.lookup_unpin( # type: ignore + connector_metadata.lookup_requests_in_step + ) + + if self.kv_role == "kv_consumer": + # Don't do 
save if the role is kv_consumer + return + + if self.use_layerwise: + for layerwise_storer in self.layerwise_storers: + next(layerwise_storer) + return + + assert len(self.kv_caches) > 0 + kvcaches = list(self.kv_caches.values()) + + assert self.lmcache_engine is not None + + for request in connector_metadata.requests: + save_spec = request.save_spec + if ( + save_spec is None or not save_spec.can_save + ) and self.kv_role != "kv_producer": + continue + + token_ids = request.token_ids + + slot_mapping = request.slot_mapping + assert isinstance(slot_mapping, torch.Tensor) + assert len(slot_mapping) == len(token_ids) + assert save_spec is not None + + # TODO: have a pre-allocated buffer to hold the slot_mappings + slot_mapping = slot_mapping.cuda() + + skip_leading_tokens = save_spec.skip_leading_tokens + if self.kv_role == "kv_producer": + assert request.disagg_spec is not None + skip_leading_tokens = min( + skip_leading_tokens, request.disagg_spec.num_transferred_tokens + ) + + if skip_leading_tokens == len(token_ids): + continue # skip this request + # Align to lmcache chunk size + skip_leading_tokens = ( + skip_leading_tokens + // self._lmcache_chunk_size + * self._lmcache_chunk_size + ) + + store_mask = torch.ones(len(token_ids), dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + is_last_prefill = request.is_last_prefill + if is_last_prefill: + if request.disagg_spec: + request.disagg_spec.is_last_prefill = True + else: + token_len = len(token_ids) + aligned_token_len = ( + token_len // self._lmcache_chunk_size * self._lmcache_chunk_size + ) + token_ids = token_ids[:aligned_token_len] + store_mask = store_mask[:aligned_token_len] + slot_mapping = slot_mapping[:aligned_token_len] + + self.lmcache_engine.store( + token_ids, + mask=store_mask, + kvcaches=kvcaches, + slot_mapping=slot_mapping, + offset=skip_leading_tokens, + transfer_spec=request.disagg_spec, + request_configs=request.request_configs, + ) + + # NOTE(Jiayi): We assume all tokens are saved + save_spec.skip_leading_tokens = len(token_ids) + if request.disagg_spec: + request.disagg_spec.num_transferred_tokens = len(token_ids) + + @_lmcache_nvtx_annotate + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + return None, None + + ################### + # Scheduler side APIs + #################### + + @_lmcache_nvtx_annotate + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> int | None: + """ + Check for external KV cache hit. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. 
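+
+        Example (illustrative numbers): if LMCache reports 96 reusable
+        tokens for a 100-token prompt and vLLM already has 32 tokens
+        cached locally, this returns 96 - 32 = 64; on a full 100-token
+        hit the last token is recomputed, so 100 - 32 - 1 = 67 is returned.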
+        """
+        if self.kv_role == "kv_producer" and not hasattr(
+            self.lookup_client, "supports_producer_reuse"
+        ):
+            return 0
+
+        self._requests_priority[request.request_id] = request.priority
+
+        token_ids = request.prompt_token_ids
+
+        # If the request has multimodal hashes, apply them to the token ids
+        mm_hashes, mm_positions = extract_mm_features(request)
+        if mm_hashes and mm_positions:
+            # TODO(Jiayi): Optimize this
+            token_ids_tensor = torch.tensor(request.prompt_token_ids)
+            apply_mm_hashes_to_token_ids(token_ids_tensor, mm_hashes, mm_positions)
+            token_ids = token_ids_tensor.tolist()
+
+        if request.sampling_params:
+            request_configs = extract_request_configs(request.sampling_params)
+        else:
+            request_configs = None
+
+        if self.skip_last_n_tokens > 0:
+            assert token_ids is not None
+            token_ids = token_ids[: -self.skip_last_n_tokens]
+        lookup_id = request.request_id if self.async_loading else str(uuid.uuid4())
+
+        self._lookup_requests_in_step.append(lookup_id)
+
+        num_external_hit_tokens = self.lookup_client.lookup(
+            token_ids,
+            lookup_id=lookup_id,
+            request_configs=request_configs,
+        )
+
+        if num_external_hit_tokens is None:
+            logger.info(
+                "Reqid: %s, Total tokens %d, LMCache hit tokens: None.",
+                request.request_id,
+                request.num_tokens,
+            )
+            return None
+
+        # When the prompt length is divisible by the block size and all
+        # blocks are cached, we need to recompute the last token. This will
+        # be removed in the future if vLLM's scheduler provides better
+        # support for this case.
+        need_to_allocate = num_external_hit_tokens - num_computed_tokens
+
+        # In the full-prompt-hit case, we need to recompute the last token
+        if num_external_hit_tokens == request.num_tokens:
+            need_to_allocate -= 1
+
+        logger.info(
+            "Reqid: %s, Total tokens %d, LMCache hit tokens: %d, need to load: %d",
+            request.request_id,
+            request.num_tokens,
+            num_external_hit_tokens,
+            need_to_allocate,
+        )
+
+        self.load_specs[request.request_id] = LoadSpec(
+            vllm_cached_tokens=num_computed_tokens,
+            lmcache_cached_tokens=num_external_hit_tokens,
+            can_load=False,
+        )
+
+        if need_to_allocate <= 0:
+            return 0
+
+        return need_to_allocate
+
+    @_lmcache_nvtx_annotate
+    def update_state_after_alloc(self, request: "Request", num_external_tokens: int):
+        """
+        Update KVConnector state after temporary buffer allocation.
+
+        Record the disaggregation spec (if provided) and mark the request's
+        load spec as loadable once the KV cache manager has allocated blocks
+        for the externally cached tokens.
+ """ + + kv_transfer_params = ( + request.kv_transfer_params + if hasattr(request, "kv_transfer_params") + else None + ) + + if kv_transfer_params is not None and "disagg_spec" in kv_transfer_params: + req_disagg_spec = kv_transfer_params["disagg_spec"] + + receiver_id = req_disagg_spec["receiver_host"] + str( + req_disagg_spec["receiver_init_port"] + ) + + disagg_spec = DisaggSpec( + req_id=req_disagg_spec["req_id"], + receiver_id=receiver_id, + receiver_host=req_disagg_spec["receiver_host"], + receiver_init_port=req_disagg_spec["receiver_init_port"], + receiver_alloc_port=req_disagg_spec["receiver_alloc_port"], + ) + + tmp_disagg_tracker[request.request_id] = disagg_spec + self._unfinished_requests[request.request_id] = request + + if request.request_id not in self.load_specs: + # No KV tokens from external KV cache, return + return + + if num_external_tokens == 0: + # No need to load anything + self.load_specs[request.request_id].can_load = False + return + + # Only check for non-prompt-hit case + if ( + self.load_specs[request.request_id].lmcache_cached_tokens + != request.num_tokens + ): + assert ( + num_external_tokens > 0 + and num_external_tokens + == self.load_specs[request.request_id].lmcache_cached_tokens + - self.load_specs[request.request_id].vllm_cached_tokens + ), ( + f"Mismatch in number of tokens: {num_external_tokens} vs " + f"{self.load_specs[request.request_id].lmcache_cached_tokens} -" + f" {self.load_specs[request.request_id].vllm_cached_tokens}" + f" for request {request.request_id}" + ) + + self.load_specs[request.request_id].can_load = True + + @_lmcache_nvtx_annotate + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + """Attach the connector metadata to the request object. + + This function should NOT modify other fields in the scheduler_output + except the `kv_connector_metadata` field. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + + force_skip_save = self.kv_role == "kv_consumer" or self.force_skip_save + + meta = LMCacheConnectorMetadata() + + # set and update lookup requests for unpin + meta.lookup_requests_in_step = self._lookup_requests_in_step + self._lookup_requests_in_step = [] + + for finished_req_id in scheduler_output.finished_req_ids: + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + + for request in scheduler_output.scheduled_new_reqs: + # Right now, we only load KV for new requests + load_spec = self.load_specs.pop(request.req_id, None) + num_tokens_to_compute = ( + request.num_computed_tokens + + scheduler_output.num_scheduled_tokens[request.req_id] + ) + lmcache_cached_tokens = 0 + if load_spec is not None: + lmcache_cached_tokens = load_spec.lmcache_cached_tokens + request_priority = self._requests_priority.pop(request.req_id, 0) + + skip_save = force_skip_save or ( + self.config.priority_limit is not None + and request_priority > self.config.priority_limit + ) + + request_tracker = RequestTracker.from_new_request( + self.config, + request, + num_tokens_to_compute, + lmcache_cached_tokens, + skip_save, + ) + self._request_trackers[request.req_id] = request_tracker + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + self._lmcache_chunk_size, + load_spec=load_spec, + discard_partial_chunks=self._discard_partial_chunks, + save_decode_cache=self._save_decode_cache, + ) + if req_meta is not None: + meta.add_request(req_meta) + + cached_reqs = scheduler_output.scheduled_cached_reqs + + # NOTE: For backward compatibility with vllm version < 0.9.2, + # In the latest vllm version, the type of scheduled_cached_reqs has + # changed from list to object `CachedRequestData` + if isinstance(cached_reqs, list): + for i, req in enumerate(cached_reqs): + request_tracker = self._request_trackers[req.req_id] + request_tracker.update(req.new_token_ids, req.new_block_ids) + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + self._lmcache_chunk_size, + load_spec=None, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + return meta + + for i, req_id in enumerate(cached_reqs.req_ids): + request_tracker = self._request_trackers[req_id] + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + if cached_request := self._unfinished_requests.get(req_id): + num_current_tokens = len(request_tracker.token_ids) + new_token_ids = cached_request.all_token_ids[ + num_current_tokens : num_current_tokens + num_new_tokens + ] + else: + raise ValueError( + f"Request {req_id} is not in _unfinished_requests, " + f"but it is scheduled to be cached" + ) + new_block_ids = cached_reqs.new_block_ids[i] + + request_tracker.update(new_token_ids, new_block_ids) + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + self._lmcache_chunk_size, + load_spec=None, + discard_partial_chunks=self._discard_partial_chunks, + save_decode_cache=self._save_decode_cache, + ) + if req_meta is not None: + meta.add_request(req_meta) + + return meta + + @_lmcache_nvtx_annotate + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + params = ( + request.kv_transfer_params + if hasattr(request, "kv_transfer_params") + else None + ) + return_params = None + + # NOTE: Used to stream back the first token + # for disagg prefill + if params is not None and "ret_first_tok" in 
params: + return_params = { + "first_tok": request._output_token_ids[0], + } + + return False, return_params diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py index 879cc9a23581..21002fe572c5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from typing import Any, Optional, Union +from typing import Any from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory @@ -32,7 +32,7 @@ def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats": """ raise NotImplementedError - def reduce(self) -> dict[str, Union[int, float]]: + def reduce(self) -> dict[str, int | float]: """ Reduce the observations collected during a time interval to one or more representative values (eg avg/median/sum of the series). @@ -58,7 +58,7 @@ def __init__(self, kv_tranfer_config: KVTransferConfig): self.reset() def reset(self): - self.transfer_stats_accumulator: Optional[KVConnectorStats] = None + self.transfer_stats_accumulator: KVConnectorStats | None = None def observe(self, transfer_stats_data: dict[str, Any]): # Should not be called when a KVConnector is not configured. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index e48d4ccd1d6c..c1a2ac012415 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -3,7 +3,7 @@ import copy from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import torch @@ -33,7 +33,7 @@ @dataclass class MultiKVConnectorMetadata(KVConnectorMetadata): metadata: tuple[KVConnectorMetadata, ...] 
- extra_async_saves: Optional[dict[str, int]] = None + extra_async_saves: dict[str, int] | None = None @dataclass @@ -86,13 +86,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._connectors: list[KVConnectorBase_V1] = [] self._ktc_kv_transfer_config = [] - ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "connectors" - ) + ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors") assert ktcs is not None for ktc in ktcs: temp_config = copy.copy(vllm_config) - engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id) + engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id) temp_config.kv_transfer_config = KVTransferConfig( **ktc, engine_id=engine_id ) @@ -130,7 +128,7 @@ def clear_connector_metadata(self) -> None: c.clear_connector_metadata() def shutdown(self): - exception: Optional[Exception] = None + exception: Exception | None = None for c in self._connectors: try: c.shutdown() @@ -169,7 +167,7 @@ def wait_for_save(self): def get_finished( self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + ) -> tuple[set[str] | None, set[str] | None]: finished_sending: set[str] = set() finished_recving: set[str] = set() for c in self._connectors: @@ -207,7 +205,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> tuple[Optional[int], bool]: + ) -> tuple[int | None, bool]: to_return = (0, False) for i, c in enumerate(self._connectors): toks, load_async = c.get_num_new_matched_tokens( @@ -258,7 +256,7 @@ def request_finished( self, request: "Request", blocks: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: async_saves = 0 kv_txfer_params = None for c in self._connectors: @@ -286,7 +284,7 @@ def take_events(self) -> Iterable["KVCacheEvent"]: yield from c.take_events() @classmethod - def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str]: + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: """ Get the required KV cache layout for this connector. Args: @@ -296,6 +294,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str] str: the required KV cache layout. e.g. HND, or NHD. None if the connector does not require a specific layout. """ + assert vllm_config.kv_transfer_config is not None ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( "connectors" ) @@ -323,17 +322,47 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str] @classmethod def build_kv_connector_stats( - cls, data: Optional[dict[str, Any]] = None - ) -> Optional[KVConnectorStats]: - return ( - MultiKVConnectorStats(data=data) - if data is not None - else MultiKVConnectorStats() - ) + cls, data: dict[str, Any] | None = None + ) -> KVConnectorStats | None: + if data is None: + return MultiKVConnectorStats() + + # data is a dict mapping connector name to their stats data. + # The stats data can be either: + # 1. Already-instantiated KVConnectorStats objects (same process) + # 2. 
Serialized dicts (cross-process after serialization) + # We need to reconstruct proper KVConnectorStats objects from dicts + reconstructed_data = {} + for connector_name, stats_value in data.items(): + # If already a KVConnectorStats object, use it directly + if isinstance(stats_value, KVConnectorStats): + reconstructed_data[connector_name] = stats_value + continue + + # Otherwise, reconstruct from serialized dict + # Get the connector class to reconstruct its stats + connector_cls = KVConnectorFactory.get_connector_class_by_name( + connector_name + ) + + # stats_value is the serialized dataclass which contains {'data': {...}} + # We need to extract the inner 'data' field to avoid double-nesting + assert isinstance(stats_value, dict) and "data" in stats_value, ( + f"Expected a dict with a 'data' field, got {stats_value}" + ) + inner_data = stats_value["data"] + + # Use the connector's build_kv_connector_stats to reconstruct + if reconstructed_stats := connector_cls.build_kv_connector_stats( + data=inner_data + ): + reconstructed_data[connector_name] = reconstructed_stats + + return MultiKVConnectorStats(data=reconstructed_data) - def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]: + def get_kv_connector_stats(self) -> MultiKVConnectorStats | None: # Group connector stats by connector type. - stats_by_connector: Optional[MultiKVConnectorStats] = None + stats_by_connector: MultiKVConnectorStats | None = None for c in self._connectors: stats = c.get_kv_connector_stats() if stats is None: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e3e3389fd164..72fcb5cd5bb7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -13,7 +13,7 @@ from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import msgspec import numpy as np @@ -21,8 +21,8 @@ import zmq from vllm import envs -from vllm.attention.backends.registry import _Backend -from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, @@ -36,11 +36,10 @@ get_tensor_model_parallel_world_size, get_tp_group, ) -from vllm.distributed.utils import divide from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import make_zmq_path, make_zmq_socket +from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput @@ -68,6 +67,7 @@ NixlWrapper = None nixlXferTelemetry = None + try: from nixl._api import nixl_agent_config except ImportError: @@ -153,10 +153,10 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler: Optional[NixlConnectorScheduler] = ( + self.connector_scheduler: NixlConnectorScheduler | None = ( NixlConnectorScheduler(vllm_config, self.engine_id) ) - 
self.connector_worker: Optional[NixlConnectorWorker] = None + self.connector_worker: NixlConnectorWorker | None = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None self.connector_worker = NixlConnectorWorker(vllm_config, self.engine_id) @@ -189,7 +189,7 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig): def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int - ) -> tuple[Optional[int], bool]: + ) -> tuple[int | None, bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( request, num_computed_tokens @@ -214,7 +214,7 @@ def request_finished( self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) @@ -234,14 +234,20 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: assert self.connector_worker is not None return self.connector_worker.get_finished() - def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: + def get_block_ids_with_load_errors(self) -> set[int]: + """Get block IDs that failed to load via NIXL.""" assert self.connector_worker is not None + return self.connector_worker.get_block_ids_with_load_errors() + + def get_kv_connector_stats(self) -> KVConnectorStats | None: + if self.connector_worker is None: + return None return self.connector_worker.get_kv_connector_stats() @classmethod def build_kv_connector_stats( - cls, data: Optional[dict[str, Any]] = None - ) -> Optional[KVConnectorStats]: + cls, data: dict[str, Any] | None = None + ) -> KVConnectorStats | None: return ( NixlKVConnectorStats(data=data) if data is not None @@ -291,6 +297,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size ) + assert vllm_config.kv_transfer_config is not None self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" logger.info("Initializing NIXL Scheduler %s", engine_id) @@ -334,7 +341,8 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. - count = len(request.prompt_token_ids) - num_computed_tokens + token_ids = request.prompt_token_ids or [] + count = len(token_ids) - num_computed_tokens if count > 0: return count, True @@ -445,7 +453,7 @@ def request_finished( self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. @@ -454,7 +462,9 @@ def request_finished( params = request.kv_transfer_params logger.debug( - "NIXLConnector request_finished, request_status=%s, kv_transfer_params=%s", + "NIXLConnector request_finished(%s), request_status=%s, " + "kv_transfer_params=%s", + request.request_id, request.status, params, ) @@ -486,6 +496,12 @@ def request_finished( if delay_free_blocks: # Prefill request on remote. 
It will be read from D upon completion + logger.debug( + "NIXLConnector request_finished(%s) waiting for %d seconds " + "for remote decode to fetch blocks", + request.request_id, + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + ) self._reqs_need_send[request.request_id] = ( time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT ) @@ -504,6 +520,74 @@ def request_finished( class NixlConnectorWorker: """Implementation of Worker side methods""" + _POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms + + @dataclass + class TpKVTopology: + """ + Helper class for tensor parallel and KV topology information for + mapping between local and remote TP workers. + """ + + tp_size: int + tp_rank: int + remote_tp_size: dict[EngineId, int] + is_mla: bool + total_num_kv_heads: int + + def tp_ratio( + self, + remote_tp_size: int, + ) -> int: + """ + Calculate the tensor parallel ratio between local and remote TP. + We can think of it as the number of local TP workers-per-remote TP + workers. Local workers will read from the same remote TP worker in + groups of size `tp_ratio`. + """ + assert self.tp_size % remote_tp_size == 0, ( + f"Local tensor parallel size {self.tp_size} is not divisible " + f"by remote tensor parallel size {remote_tp_size}." + ) + return self.tp_size // remote_tp_size + + def tp_ratio_from_engine_id( + self, + remote_engine_id: EngineId, + ) -> int: + remote_tp_size = self.remote_tp_size[remote_engine_id] + return self.tp_ratio(remote_tp_size) + + def is_kv_replicated(self, engine_id: EngineId) -> bool: + """ + Whether the KV cache is replicated across TP workers due to the + number of TP workers being greater than the number of KV heads. + """ + tp_size = self.remote_tp_size[engine_id] + return tp_size // self.total_num_kv_heads >= 1 + + def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: + # MLA is always replicated as the hidden dim can't be split. + return self.is_mla or self.is_kv_replicated(remote_engine_id) + + def get_target_remote_rank( + self, + remote_tp_size: int, + ) -> int: + """ + Get the remote TP rank (on P) that the current local TP rank + (on D) will read from. + """ + tp_ratio = self.tp_ratio(remote_tp_size) + return self.tp_rank // tp_ratio + + def get_target_remote_rank_from_engine_id( + self, + remote_engine_id: EngineId, + ) -> int: + remote_tp_size = self.remote_tp_size[remote_engine_id] + return self.get_target_remote_rank(remote_tp_size) + def __init__(self, vllm_config: VllmConfig, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") @@ -515,21 +599,35 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size + if vllm_config.kv_transfer_config is None: + raise ValueError("kv_transfer_config must be set for NixlConnector") + self.kv_transfer_config = vllm_config.kv_transfer_config + self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config( "backends", ["UCX"] ) # TODO temporary, once nixl allows for telemetry flag in config # (next release), we can remove this env var. os.environ["NIXL_TELEMETRY_ENABLE"] = "1" + # Agent. non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] + # Configure NIXL num_threads to avoid UAR exhaustion on Mellanox NICs. + # Each UCX thread allocates UARs (doorbell pages) via DevX, and + # excessive NIXL UAR usage can exhaust NIC UAR space. 
This can cause + # components like NVSHMEM (used by DeepEP kernels) to fail during RDMA + # initialization with "mlx5dv_devx_alloc_uar" errors. + # Ref: https://network.nvidia.com/files/doc-2020/ethernet-adapters-programming-manual.pdf#page=63 + num_threads = vllm_config.kv_transfer_config.get_from_extra_config( + "num_threads", 4 + ) if nixl_agent_config is None: config = None else: config = ( nixl_agent_config(backends=self.nixl_backends) if len(non_ucx_backends) > 0 - else nixl_agent_config(num_threads=8) + else nixl_agent_config(num_threads=num_threads) ) self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) @@ -552,6 +650,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() self.num_blocks = 0 + self.enable_permute_local_kv = False # KV Caches and nixl tracking data. self.device_type = current_platform.device_type @@ -571,20 +670,21 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.use_host_buffer = self.kv_buffer_device == "cpu" # support for oot platform which can't register nixl memory # type based on kv_buffer_device - self.nixl_memory_type = current_platform.get_nixl_memory_type() - if self.nixl_memory_type is None: + nixl_memory_type = current_platform.get_nixl_memory_type() + if nixl_memory_type is None: if self.kv_buffer_device == "cuda": - self.nixl_memory_type = "VRAM" + nixl_memory_type = "VRAM" elif self.kv_buffer_device == "cpu": - self.nixl_memory_type = "DRAM" - if self.nixl_memory_type is None: + nixl_memory_type = "DRAM" + if nixl_memory_type is None: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " "is not supported." ) + self.nixl_memory_type = nixl_memory_type # Note: host xfer buffer ops when use_host_buffer is True - self.copy_blocks: Optional[CopyBlocksOp] = None + self.copy_blocks: CopyBlocksOp | None = None # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. @@ -614,8 +714,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Set of requests that have been part of a batch, regardless of status. self._reqs_to_process: set[ReqId] = set() + # invalid blocks from failed NIXL operations + self._invalid_block_ids: set[int] = set() + # requests that skipped transfer (handshake or transfer failures) + self._failed_recv_reqs: set[ReqId] = set() + # Background thread for handling new handshake requests. - self._nixl_handshake_listener_t: Optional[threading.Thread] = None + self._nixl_handshake_listener_t: threading.Thread | None = None + self._nixl_handshake_listener_stop_event: threading.Event | None = None # Background thread for initializing new NIXL handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. @@ -627,7 +733,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Protects _handshake_futures and _remote_agents. 
self._handshake_lock = threading.RLock() - self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -635,7 +740,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) # List of block window sizes for each layer for local attention - self.block_window_per_layer: list[Optional[int]] = [] + self.block_window_per_layer: list[int | None] = [] self.use_mla = self.model_config.use_mla backend = get_attn_backend( @@ -650,6 +755,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._use_flashinfer = attn_backend == _Backend.FLASHINFER self._use_pallas = attn_backend == _Backend.PALLAS self.kv_cache_layout = get_kv_cache_layout() + self.host_buffer_kv_cache_layout = self.kv_cache_layout logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected kv cache layout %s", self.kv_cache_layout) @@ -659,10 +765,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() + self.kv_topo = self.TpKVTopology( + tp_size=self.world_size, + tp_rank=self.tp_rank, + remote_tp_size=self._tp_size, # shared state + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + ) + @staticmethod def _nixl_handshake_listener( metadata: NixlAgentMetadata, ready_event: threading.Event, + stop_event: threading.Event, base_port: int, tp_rank: int, ): @@ -681,7 +796,14 @@ def _nixl_handshake_listener( logger.debug("Starting listening on path: %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() - while True: + poller = zmq.Poller() + poller.register(sock, zmq.POLLIN) + while not stop_event.is_set(): + events = dict( + poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000) + ) + if sock not in events: + continue identity, _, msg = sock.recv_multipart() if msg != GET_META_MSG: logger.warning("Connection listener got unexpected message %s", msg) @@ -704,8 +826,7 @@ def _nixl_handshake( # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. - tp_ratio = self._tp_size[self.engine_id] // remote_tp_size - p_remote_rank = self.tp_rank // tp_ratio + p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size) path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug( "Querying metadata on path: %s at remote rank %s", path, p_remote_rank @@ -713,6 +834,8 @@ def _nixl_handshake( # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: + # Set receive timeout to 5 seconds to avoid hanging on dead server + sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds sock.send(GET_META_MSG) metadata_bytes = sock.recv() decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) @@ -753,6 +876,20 @@ def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> Non for layer_name, kv_cache in kv_caches.items(): kv_shape = kv_cache.shape kv_dtype = kv_cache.dtype + if ( + self.kv_cache_layout == "NHD" + and self.vllm_config.kv_transfer_config is not None + and self.vllm_config.kv_transfer_config.enable_permute_local_kv + ): + logger.info_once( + "'enable_permute_local_kv' flag is enabled while " + "device KV Layout is NHD. 
Init host buffer with" + " HND to better support Decode/Prefill TP_ratio > 1." + ) + # Since NHD will not support Decode/Prefill TP_ratio > 1, + # we can leverage host_buffer for permute + self.host_buffer_kv_cache_layout = "HND" + kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4]) xfer_buffers[layer_name] = torch.empty( kv_shape, dtype=kv_dtype, device="cpu" ) @@ -795,10 +932,20 @@ def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): fut.add_done_callback(done_callback) - # TODO: handle failure state of future in the - # callback, we want to fail the request in this case. - def request_ready(_f: Future[Any], entry=(req_id, meta)): - self._ready_requests.put(entry) + # check handshake success before proceeding with request + def request_ready(f: Future[Any], entry=(req_id, meta)): + try: + # check if handshake succeeded + f.result() + self._ready_requests.put(entry) + except Exception: + # handshake failed - mark blocks as invalid + logger.exception( + "Handshake failed for request %s, marking blocks as invalid", req_id + ) + if req_meta := self._recving_metadata.get(req_id): + self._invalid_block_ids.update(req_meta.local_block_ids) + self._failed_recv_reqs.add(req_id) fut.add_done_callback(request_ready) @@ -950,13 +1097,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. - if self.vllm_config.model_config.hf_config.model_type == "llama4": + if self.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig - assert isinstance( - self.vllm_config.model_config.hf_text_config, Llama4TextConfig - ) - llama4_config = self.vllm_config.model_config.hf_text_config + assert isinstance(self.model_config.hf_text_config, Llama4TextConfig) + llama4_config = self.model_config.hf_text_config no_rope_layers = llama4_config.no_rope_layers chunk_size = llama4_config.attention_chunk_size chunk_block_size = math.ceil(chunk_size / self.block_size) @@ -980,16 +1125,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, - kv_cache_layout=self.kv_cache_layout, + kv_cache_layout=self.kv_cache_layout + if not self.use_host_buffer + else self.host_buffer_kv_cache_layout, ) - ready_event = threading.Event() + ready_event, stop_event = threading.Event(), threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, - args=(metadata, ready_event, self.side_channel_port, self.tp_rank), + args=( + metadata, + ready_event, + stop_event, + self.side_channel_port, + self.tp_rank, + ), daemon=True, name="nixl_handshake_listener", ) self._nixl_handshake_listener_t.start() + self._nixl_handshake_listener_stop_event = stop_event ready_event.wait() # Wait for listener ZMQ socket to be ready. 
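+        # The listener thread started above polls its ZMQ socket every
+        # _POLL_TIMEOUT seconds and exits once
+        # self._nixl_handshake_listener_stop_event is set.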
def add_remote_agent( @@ -1039,87 +1193,51 @@ def add_remote_agent( engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery if remote_tp_rank in self._remote_agents.get(engine_id, {}): + logger.debug( + "Remote agent with engine_id %s and rank" + "%s already exchanged metadata, skip handshake.", + engine_id, + remote_tp_rank, + ) return self._remote_agents[engine_id][remote_tp_rank] + ### Register remote agent metadata if engine_id not in self._tp_size: self._tp_size[engine_id] = remote_tp_size - else: - assert self._tp_size[engine_id] == remote_tp_size - # TODO We may eventually want to skip enforcing the same attn backend. - assert nixl_agent_meta.attn_backend_name == self.backend_name remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata ) - # Number of D TP workers reading from a single P TP worker. This is - # 1 when P and D `--tensor-parallel-size` match. - tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id]) - assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" - assert not self._use_pallas or tp_ratio == 1, ( - "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." - ) - # Handle tp_size>num_kv_heads: replicate KV cache. - total_num_kv_heads = self.model_config.get_total_num_kv_heads() - is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 + replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id) - remote_block_len = nixl_agent_meta.block_lens[0] - if self.use_mla or is_kv_replicated: - # With replicated KV cache, only the number of blocks can differ. - assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( - "KV cache sizes must match between P and D when replicated" - ) - remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) - else: - # When MLA is not used, this is a list of the same block length - for block_len in nixl_agent_meta.block_lens: - assert block_len == remote_block_len, ( - "All remote layers must have the same block size" - ) - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] * tp_ratio - ) - if self._use_flashinfer: - # With flashinfer, KV are sent in the same message. - remote_block_size //= 2 - if tp_ratio > 1: - # Heterogeneous TP expects same kv_cache_layout. - assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout - if self.device_type == "xpu": - raise ValueError("Heterogeneous TP is not supported on XPU") + # Create dst descs and xfer side handles. TP workers have same #blocks + # so we only register once per engine_id. + if engine_id not in self.dst_num_blocks: + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks - assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." - ) + # Keep track of remote agent kv caches base addresses. + self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr - assert self.block_size == remote_block_size, ( - "Remote P worker with different page/block size is not supported " - f"{self.block_size=}, {remote_block_size=}" - ) + self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) - # Create dst descs and xfer side handles. TP workers have same #blocks. 
- if engine_id in self.dst_num_blocks: - assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks - else: - self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + # Number of D TP workers reading from a single P TP worker. This is + # 1 when P and D `--tensor-parallel-size` match. + tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id) + ### Register remote agent memory regions blocks_data = [] # With homogeneous TP, D pulls the whole kv cache from corresponding # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr - assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) rank_offset = ( - self.tp_rank % tp_ratio * kv_block_len - if not (self.use_mla or is_kv_replicated) - else 0 + self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 ) for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1154,6 +1272,82 @@ def add_remote_agent( return remote_agent_name + def _validate_remote_agent_handshake( + self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int + ): + """ + Validate the remote agent handshake metadata ensuring the + invariants hold true. + """ + remote_engine_id = nixl_agent_meta.engine_id + + assert self._tp_size[remote_engine_id] == remote_tp_size + # TODO We may eventually want to skip enforcing the same attn backend. + assert nixl_agent_meta.attn_backend_name == self.backend_name + + tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) + assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + assert not self._use_pallas or tp_ratio == 1, ( + "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." + ) + kv_cache_layout = ( + self.kv_cache_layout + if not self.use_host_buffer + else self.host_buffer_kv_cache_layout + ) + if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout: + if ( + self.kv_transfer_config.enable_permute_local_kv + and nixl_agent_meta.kv_cache_layout == "HND" + ): + logger.info( + "Remote is HND and local is NHD, enabled additional permute " + "on local device KV." + ) + self.enable_permute_local_kv = True + else: + raise RuntimeError( + "Heterogeneous TP expects same kv_cache_layout. " + "Or enable experimental feature to use HND to NHD support by " + "setting 'enable_permute_local_kv'=True in --kv-transfer-config." + ) + + # Block len can only vary across layers when using MLA. + remote_block_len = nixl_agent_meta.block_lens[0] + if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id): + # With replicated KV cache, only the number of blocks can differ. 
+ assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( + "KV cache sizes must match between P and D when replicated" + ) + remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) + else: + # When MLA is not used, this is a list of the same block length + for block_len in nixl_agent_meta.block_lens: + assert block_len == remote_block_len, ( + "All remote layers must have the same block size" + ) + remote_block_size = remote_block_len // ( + self.slot_size_per_layer[0] * tp_ratio + ) + if self._use_flashinfer: + # With flashinfer, KV are sent in the same message. + remote_block_size //= 2 + + assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( + "Remote P worker KV layer cache must be of shape [2, N, " + "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." + ) + + assert self.block_size == remote_block_size, ( + "Remote P worker with different page/block size is not supported " + f"{self.block_size=}, {remote_block_size=}" + ) + + # TP workers have same #blocks. + assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks + + assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) + def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" assert self.use_host_buffer @@ -1197,6 +1391,41 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): "d2h", ) + def permute_device_kv(self, block_ids: list[int]): + """Transforms the layout of received KV cache blocks to the local format. + + This method corrects layout mismatches from direct memory copies by + permuting the tensor dimensions. + + - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]` + - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]` + + Args: + block_ids: A list of block IDs to update and permute. + + Implementation: + - x = blocks_to_update.reshape(src_shape) # view local kv with sender layout + - permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size + - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back + + """ + split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) + inv_order = [0, 2, 1, 3] + sample_cache = list(self.device_kv_caches.values())[0][0] + target_shape = list(sample_cache.shape) + target_shape[0] = -1 + src_shape = tuple(target_shape[i] for i in inv_order) + indices = torch.tensor(block_ids, device=sample_cache.device) + + for _, cache_or_caches in self.device_kv_caches.items(): + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + for cache in cache_list: + blocks_to_update = cache.index_select(0, indices) + permuted_blocks = blocks_to_update.reshape(src_shape).permute( + *inv_order + ) + cache.index_copy_(0, indices, permuted_blocks) + def get_finished(self) -> tuple[set[str], set[str]]: """ Get requests that are done sending or recving on this specific worker. 
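A self-contained sketch of the fix-up that `permute_device_kv` documents above, using a toy NHD cache and two blocks that were written in the sender's HND order; all sizes are made up for illustration and this is not the connector's actual cache handling.

import torch

n_blocks, n_kv_heads, block_size, head_dim = 8, 4, 16, 64
# Local cache uses NHD: [num_blocks, block_size, n_kv_heads, head_dim].
cache = torch.zeros(n_blocks, block_size, n_kv_heads, head_dim)

# Simulate a raw copy landing in blocks 1 and 3 with the sender's HND ordering.
received_hnd = torch.randn(2, n_kv_heads, block_size, head_dim)
indices = torch.tensor([1, 3])
cache.index_copy_(0, indices, received_hnd.reshape(2, block_size, n_kv_heads, head_dim))

# View those blocks with the HND source shape, then permute heads/tokens back
# into the local NHD order and copy the result into place.
inv_order = [0, 2, 1, 3]
src_shape = (len(indices), n_kv_heads, block_size, head_dim)
permuted = cache.index_select(0, indices).reshape(src_shape).permute(*inv_order)
cache.index_copy_(0, indices, permuted)

assert torch.equal(cache.index_select(0, indices), received_hnd.permute(*inv_order))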
@@ -1205,6 +1434,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: """ done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) + + # add requests that skipped transfer to done_recving + done_recving.update(self._failed_recv_reqs) + self._failed_recv_reqs.clear() + if len(done_sending) > 0 or len(done_recving) > 0: logger.debug( "Rank %s, get_finished: %s requests done sending " @@ -1214,11 +1448,17 @@ def get_finished(self) -> tuple[set[str], set[str]]: len(done_recving), ) - if self.use_host_buffer: - for req_id in done_recving: - meta = self._recving_metadata.pop(req_id) - assert meta, f"{req_id} not found in recving_metadata list" + block_ids_to_permute = [] + for req_id in done_recving: + # clean up metadata for completed requests + meta = self._recving_metadata.pop(req_id, None) + assert meta is not None, f"{req_id} not found in recving_metadata list" + if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) + if self.enable_permute_local_kv: + block_ids_to_permute += meta.local_block_ids + if len(block_ids_to_permute) > 0: + self.permute_device_kv(block_ids_to_permute) # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() @@ -1296,7 +1536,19 @@ def _pop_done_transfers( in_progress = True continue else: - raise RuntimeError("Transfer failed with state %s", xfer_state) + # transfer failed - mark blocks as invalid + logger.error( + "NIXL transfer failed for request %s with state %s. " + "Marking blocks as invalid.", + req_id, + xfer_state, + ) + # mark all blocks for this request as invalid + if meta := self._recving_metadata.pop(req_id, None): + self._invalid_block_ids.update(meta.local_block_ids) + self._recving_metadata.pop(req_id, None) + self.nixl_wrapper.release_xfer_handle(handle) + self.xfer_stats.record_failed_transfer() if not in_progress: done_req_ids.add(req_id) del transfers[req_id] @@ -1317,8 +1569,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): len(meta.local_block_ids), len(meta.remote_block_ids), ) - if self.use_host_buffer: - self._recving_metadata[req_id] = meta + # always store metadata for failure recovery + self._recving_metadata[req_id] = meta if remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: @@ -1345,6 +1597,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): # Remove all requests that are not to be processed (eg aborted). for req_id in metadata.reqs_not_processed: self._reqs_to_process.discard(req_id) + # We should never get an abort after setting an expiry timer + assert req_id not in self._reqs_to_send # Add to requests that are waiting to be read and track expiration. for req_id, expiration_time in metadata.reqs_to_send.items(): @@ -1383,16 +1637,27 @@ def _read_blocks( # Number of D TP workers that will read from dst P. Propagate tp_ratio # on notification so that dst worker can wait before freeing blocks. - tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] + tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id) notif_id = f"{request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, # just notify P worker that we have the blocks we need. 
num_local_blocks = len(local_block_ids) if num_local_blocks == 0: - remote_rank = self.tp_rank // tp_ratio + remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id( + dst_engine_id + ) agent_name = self._remote_agents[dst_engine_id][remote_rank] - self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) + try: + self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) + except Exception: + logger.exception( + "NIXL send_notif failed for request %s: " + "P worker blocks will be freed after timeout. " + "This may indicate network issues.", + request_id, + ) + self.xfer_stats.record_failed_notification() return # Partial prefix cache hit: just read uncomputed blocks. @@ -1454,23 +1719,38 @@ def _read_blocks( assert len(local_block_descs_ids) == len(remote_block_descs_ids) # Prepare transfer with Nixl. - handle = self.nixl_wrapper.make_prepped_xfer( - "READ", - local_xfer_side_handle, - local_block_descs_ids, - remote_xfer_side_handle, - remote_block_descs_ids, - notif_msg=notif_id, - ) + handle = None + try: + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=notif_id, + ) - # Begin async xfer. - self.nixl_wrapper.transfer(handle) + # Begin async xfer. + self.nixl_wrapper.transfer(handle) - # Use handle to check completion in future step(). - self._recving_transfers[request_id].append((handle, time.perf_counter())) + # Use handle to check completion in future step(). + self._recving_transfers[request_id].append((handle, time.perf_counter())) + except Exception: + logger.exception( + "NIXL transfer setup/initiation failed for request %s. " + "Marking blocks as invalid.", + request_id, + ) + # mark all blocks for this request as invalid + if meta := self._recving_metadata.get(request_id): + self._invalid_block_ids.update(meta.local_block_ids) + self.xfer_stats.record_failed_transfer() + if handle is not None: + self.nixl_wrapper.release_xfer_handle(handle) + self._failed_recv_reqs.add(request_id) def _get_block_descs_ids( - self, engine_id: str, block_ids: list[int], layer_idx: Optional[int] = None + self, engine_id: str, block_ids: list[int], layer_idx: int | None = None ) -> np.ndarray: """ Get the descs ids for a set of block ids. @@ -1516,7 +1796,7 @@ def get_backend_aware_kv_block_len(self, layer_idx: int): block_len = self.block_len_per_layer[layer_idx] return block_len - def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: + def get_kv_connector_stats(self) -> KVConnectorStats | None: """ Get the KV transfer stats for the connector. """ @@ -1525,11 +1805,30 @@ def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: return self.xfer_stats.clone_and_reset() return None + def get_block_ids_with_load_errors(self) -> set[int]: + """ + Return and clear the set of block IDs that failed to load. + + This is called by the scheduler to identify blocks that need + to be retried after a NIXL transfer failure. 
+ """ + result = self._invalid_block_ids + self._invalid_block_ids = set() + return result + + def __del__(self): + self.shutdown() + def shutdown(self): """Shutdown the connector worker.""" self._handshake_initiation_executor.shutdown(wait=False) + if self._nixl_handshake_listener_stop_event is not None: + self._nixl_handshake_listener_stop_event.set() + self._nixl_handshake_listener_stop_event = None if self._nixl_handshake_listener_t is not None: - self._nixl_handshake_listener_t.join(timeout=0) + # Generous timeout to allow the thread to exit + self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10) + assert not self._nixl_handshake_listener_t.is_alive() self._nixl_handshake_listener_t = None for handles in self._recving_transfers.values(): for handle, _ in handles: @@ -1557,7 +1856,7 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: if socket_type not in (zmq.ROUTER, zmq.REQ): raise ValueError(f"Unexpected socket type: {socket_type}") - ctx: Optional[zmq.Context] = None + ctx: zmq.Context | None = None try: ctx = zmq.Context() # type: ignore[attr-defined] yield make_zmq_socket( @@ -1584,6 +1883,8 @@ def reset(self): "post_duration": [], "bytes_transferred": [], "num_descriptors": [], + "num_failed_transfers": [], + "num_failed_notifications": [], } def record_transfer(self, res: nixlXferTelemetry): @@ -1593,6 +1894,14 @@ def record_transfer(self, res: nixlXferTelemetry): self.data["bytes_transferred"].append(res.totalBytes) self.data["num_descriptors"].append(res.descCount) + def record_failed_transfer(self): + """Record a failed NIXL transfer operation.""" + self.data["num_failed_transfers"].append(1.0) + + def record_failed_notification(self): + """Record a failed NIXL notification (send_notif).""" + self.data["num_failed_notifications"].append(1.0) + def clone_and_reset(self) -> "NixlKVConnectorStats": old = copy.copy(self) self.reset() @@ -1609,7 +1918,7 @@ def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: accumulator.extend(v) return self - def reduce(self) -> dict[str, Union[int, float]]: + def reduce(self) -> dict[str, int | float]: # Compute compact representative stats suitable for CLI logging if self.is_empty(): return { diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 745af0efba18..6d4ffc152de9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Iterator from dataclasses import dataclass from itertools import islice -from typing import Any, Optional +from typing import Any import torch @@ -46,8 +46,8 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): spec = OffloadingSpecFactory.create_spec(vllm_config) - self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None - self.connector_worker: Optional[OffloadingConnectorWorker] = None + self.connector_scheduler: OffloadingConnectorScheduler | None = None + self.connector_worker: OffloadingConnectorWorker | None = None if role == KVConnectorRole.SCHEDULER: self.connector_scheduler = OffloadingConnectorScheduler(spec) elif role == KVConnectorRole.WORKER: @@ -113,7 +113,7 @@ def request_finished( self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return 
self.connector_scheduler.request_finished(request, block_ids) @@ -148,7 +148,7 @@ def _get_block_hashes( self, req: Request, start_idx: int = 0, - end_idx: Optional[int] = None, + end_idx: int | None = None, ) -> Iterable[BlockHash]: return islice( req.block_hashes, @@ -354,7 +354,7 @@ def request_finished( self, request: Request, block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ Called when a request has finished, before its blocks are freed. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 0e6693db5cd2..e47cde2614fc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import regex as re import torch @@ -75,9 +75,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Any] = {} - self.config = vllm_config.kv_transfer_config - self.is_producer = self.config.is_kv_producer - self.chunked_prefill: dict[str, Any] = {} + self.is_producer = self._kv_transfer_config.is_kv_producer + self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {} self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0 self._local_rank = ( @@ -87,7 +86,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.p2p_nccl_engine = ( P2pNcclEngine( local_rank=self._local_rank, - config=self.config, + config=self._kv_transfer_config, hostname="", port_offset=self._rank, ) @@ -304,7 +303,7 @@ def wait_for_save(self): def get_finished( self, finished_req_ids: set[str], **kwargs: Any - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + ) -> tuple[set[str] | None, set[str] | None]: """ Notifies worker-side connector ids of requests that have finished generating tokens. 
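Much of the churn in this and the neighboring connector files is mechanical typing modernization: PEP 604 unions replacing `Optional`/`Union`. A two-assert sketch (Python 3.10+) showing the two spellings denote the same types at runtime, so the rewritten annotations are behavior-preserving:

from typing import Optional, Union

assert (int | None) == Optional[int]
assert (str | dict[str, str] | None) == Optional[Union[str, dict[str, str]]]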
@@ -346,7 +345,8 @@ def get_num_new_matched_tokens( if self.is_producer: return 0, False - num_external_tokens = len(request.prompt_token_ids) - 1 - num_computed_tokens + prompt_token_ids = request.prompt_token_ids or [] + num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens if num_external_tokens < 0: num_external_tokens = 0 @@ -387,7 +387,7 @@ def build_connector_meta( ] num_tokens = num_scheduled_tokens + new_req.num_computed_tokens # the request's prompt is chunked prefill - if num_tokens < len(new_req.prompt_token_ids): + if num_tokens < len(new_req.prompt_token_ids or []): # 'CachedRequestData' has no attribute 'prompt_token_ids' self.chunked_prefill[new_req.req_id] = ( new_req.block_ids[0], @@ -397,7 +397,7 @@ def build_connector_meta( # the request's prompt is not chunked prefill meta.add_request( request_id=new_req.req_id, - token_ids=new_req.prompt_token_ids, + token_ids=new_req.prompt_token_ids or [], block_ids=new_req.block_ids[0], block_size=self._block_size, ) @@ -405,7 +405,7 @@ def build_connector_meta( if new_req.req_id in self._requests_need_load: meta.add_request( request_id=new_req.req_id, - token_ids=new_req.prompt_token_ids, + token_ids=new_req.prompt_token_ids or [], block_ids=new_req.block_ids[0], block_size=self._block_size, ) @@ -421,10 +421,12 @@ def build_connector_meta( num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id] num_tokens = num_scheduled_tokens + num_computed_tokens assert req_id in self.chunked_prefill + assert new_block_ids is not None block_ids = new_block_ids[0] if not resumed_from_preemption: block_ids = self.chunked_prefill[req_id][0] + block_ids prompt_token_ids = self.chunked_prefill[req_id][1] + assert prompt_token_ids is not None # the request's prompt is chunked prefill again if num_tokens < len(prompt_token_ids): self.chunked_prefill[req_id] = (block_ids, prompt_token_ids) @@ -450,6 +452,7 @@ def build_connector_meta( # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. + assert new_block_ids is not None block_ids = new_block_ids[0] meta.add_request( @@ -466,7 +469,7 @@ def request_finished( self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ Called when a request has finished, before its blocks are freed. 
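For the consumer-side accounting in `get_num_new_matched_tokens` above, a small numeric sketch of the arithmetic (the helper name and the numbers are illustrative, not connector code): all but the last prompt token may be fetched from the remote prefill, minus whatever is already computed locally, clamped at zero.

def external_tokens(prompt_token_ids: list[int] | None, num_computed_tokens: int) -> int:
    prompt = prompt_token_ids or []
    return max(len(prompt) - 1 - num_computed_tokens, 0)

assert external_tokens(list(range(100)), 0) == 99   # fresh request
assert external_tokens(list(range(100)), 64) == 35  # partially computed locally
assert external_tokens(None, 0) == 0                # guard for a missing prompt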
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index cff68818ca70..3ef287817c39 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -5,11 +5,10 @@ import os import threading import time -import typing from collections import deque from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import msgpack import torch @@ -26,7 +25,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 TensorMemoryPool, ) -from vllm.utils import current_stream, get_ip +from vllm.utils.network_utils import get_ip +from vllm.utils.torch_utils import current_stream logger = logging.getLogger(__name__) @@ -77,7 +77,7 @@ def __init__( config: KVTransferConfig, hostname: str = "", port_offset: int = 0, - library_path: Optional[str] = None, + library_path: str | None = None, ) -> None: self.config = config self.rank = port_offset @@ -187,7 +187,7 @@ def __init__( self.nccl_num_channels, ) - def create_connect(self, remote_address: typing.Optional[str] = None): + def create_connect(self, remote_address: str | None = None): assert remote_address is not None if remote_address not in self.socks: sock = self.context.socket(zmq.DEALER) @@ -224,7 +224,7 @@ def send_tensor( self, tensor_id: str, tensor: torch.Tensor, - remote_address: typing.Optional[str] = None, + remote_address: str | None = None, ) -> bool: if remote_address is None: with self.recv_store_cv: @@ -296,7 +296,7 @@ def send_tensor( def recv_tensor( self, tensor_id: str, - remote_address: typing.Optional[str] = None, + remote_address: str | None = None, ) -> torch.Tensor: if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": start_time = time.time() @@ -527,7 +527,7 @@ def send_sync(self, item: SendQueueItem) -> bool: def get_finished( self, finished_req_ids: set[str], no_compile_layers - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + ) -> tuple[set[str] | None, set[str] | None]: """ Notifies worker-side connector ids of requests that have finished generating tokens. 
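An illustrative, standalone sketch of the lazy per-address DEALER socket cache that `create_connect` above maintains; the `tcp://` path construction and the dictionary name are assumptions (the engine builds its paths through its own ZMQ helpers).

import zmq

context = zmq.Context()
socks: dict[str, zmq.Socket] = {}

def get_sock(remote_address: str) -> zmq.Socket:
    # Create and cache one DEALER socket per remote address on first use.
    if remote_address not in socks:
        sock = context.socket(zmq.DEALER)
        sock.connect(f"tcp://{remote_address}")
        socks[remote_address] = sock
    return socks[remote_address]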
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index a1bab4e06145..d0cd4b07c51d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -3,7 +3,7 @@ import hashlib import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import safetensors import torch @@ -90,11 +90,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Request] = {} - transfer_config = vllm_config.kv_transfer_config - self._storage_path = transfer_config.get_from_extra_config( + self._storage_path = self._kv_transfer_config.get_from_extra_config( "shared_storage_path", "/tmp" ) - logger.info(vllm_config.kv_transfer_config) + logger.info(self._kv_transfer_config) logger.info("Shared storage path is %s", self._storage_path) def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: @@ -249,7 +248,7 @@ def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int, - ) -> tuple[Optional[int], bool]: + ) -> tuple[int | None, bool]: """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. @@ -277,9 +276,8 @@ def get_num_new_matched_tokens( # Now, first num_tokens_to_check tokens are hit, we need to prepare # the metadata for the worker connector to correctly load the KV - num_tokens_to_check = align_to_block_size( - len(request.prompt_token_ids) - 1, self._block_size - ) + token_ids = request.prompt_token_ids or [] + num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size) return num_tokens_to_check - num_computed_tokens, False @@ -311,13 +309,15 @@ def build_connector_meta( total_need_load = 0 for new_req in scheduler_output.scheduled_new_reqs: + token_ids = new_req.prompt_token_ids or [] + mm_hashes = [f.identifier for f in new_req.mm_features] if new_req.req_id in self._requests_need_load: meta.add_request( - token_ids=new_req.prompt_token_ids, + token_ids=token_ids, block_ids=new_req.block_ids[0], block_size=self._block_size, is_store=False, - mm_hashes=[f.identifier for f in new_req.mm_features], + mm_hashes=mm_hashes, ) total_need_load += 1 else: @@ -325,13 +325,13 @@ def build_connector_meta( # but a single request can have both store and load. # NOTE(rob): for this debug implementation, we only cache # the original prompt tokens. - if not self._found_match_for_request(new_req): + if not self._found_match_for_prompt(token_ids, mm_hashes): meta.add_request( - token_ids=new_req.prompt_token_ids, + token_ids=token_ids, block_ids=new_req.block_ids[0], block_size=self._block_size, is_store=True, - mm_hashes=[f.identifier for f in new_req.mm_features], + mm_hashes=mm_hashes, ) cached_reqs = scheduler_output.scheduled_cached_reqs @@ -355,6 +355,7 @@ def build_connector_meta( # NOTE(rob): For resumed req, new_block_ids is all # of the block_ids for the request. 
+ assert new_block_ids is not None block_ids = new_block_ids[0] meta.add_request( @@ -379,12 +380,22 @@ def _found_match_for_request( request: "Request", ) -> bool: """Check if the cache is hit for the request.""" + return self._found_match_for_prompt( + list(request.prompt_token_ids or []), + [f.identifier for f in request.mm_features], + ) + + def _found_match_for_prompt( + self, + prompt_token_ids: list[int], + mm_hashes: list[str], + ) -> bool: num_tokens_to_check = align_to_block_size( - len(request.prompt_token_ids) - 1, self._block_size + len(prompt_token_ids) - 1, self._block_size ) foldername = self._generate_foldername_debug( - torch.tensor(request.prompt_token_ids)[:num_tokens_to_check], - [f.identifier for f in request.mm_features], + torch.tensor(prompt_token_ids)[:num_tokens_to_check], + mm_hashes, create_folder=False, ) return os.path.exists(foldername) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index 08b683bfe23f..f48d03d0b0cd 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -14,7 +14,6 @@ """ from abc import ABC, abstractmethod -from typing import Optional import torch @@ -98,8 +97,8 @@ def insert( @abstractmethod def drop_select( - self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor] - ) -> list[Optional[torch.Tensor]]: + self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None + ) -> list[torch.Tensor | None]: """Select and *drop* KV cache entries from the lookup buffer. The functionality is similar to the following python statements @@ -143,7 +142,7 @@ class KVStoreBufferBase(KVCacheBufferBase): def put( self, key: str, - value: Optional[torch.Tensor], + value: torch.Tensor | None, ) -> None: """Store a key-value pair in the buffer. @@ -163,7 +162,7 @@ def put( def get( self, key: str, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """Retrieve a value from the buffer by key. Args: diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py index 44fc6d8ac5ad..7861bea1f9c5 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py @@ -10,7 +10,6 @@ import json import os from dataclasses import dataclass -from typing import Optional import torch from safetensors.torch import load as safetensors_load @@ -110,7 +109,7 @@ def close(self): def put( self, key: str, - value: Optional[torch.Tensor], + value: torch.Tensor | None, ) -> None: # A message queue needs to be introduced before making it asynchronous. if value is not None: @@ -119,7 +118,7 @@ def put( def get( self, key: str, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: # A message queue needs to be introduced before making it asynchronous. 
value = self._get_impl(key) return value @@ -142,7 +141,7 @@ def _put_impl( def _get_impl( self, key: str, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """Get KVCache from Mooncake Store""" try: data = self.store.get(key) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index cd58ec2e7639..f046a349874e 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -13,7 +13,6 @@ import threading from collections import deque -from typing import Optional, Union import torch @@ -46,7 +45,7 @@ def __init__( self.buffer_cv = threading.Condition() self.signal_pipe = signal_pipe self.data_pipe = data_pipe - self.request_handling_thread: Optional[threading.Thread] = None + self.request_handling_thread: threading.Thread | None = None self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None @@ -81,14 +80,14 @@ def _matches( return 0 - def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: + def _send_tensor_and_dec_size(self, tensor: torch.Tensor | None) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() if tensor.dtype == torch.bool: tensor = tensor.float() self.data_pipe.send_tensor(tensor) - def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]): + def _get_element_size(self, data: list | torch.Tensor | None): if isinstance(data, torch.Tensor): return data.element_size() * data.numel() if not data: @@ -184,8 +183,8 @@ def is_buffer_available( logger.debug("Closing drop_select_handler") def drop_select( - self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor] - ) -> list[Optional[torch.Tensor]]: + self, input_tokens: torch.Tensor | None, roi: torch.Tensor | None + ) -> list[torch.Tensor | None]: assert self.request_handling_thread is None, ( "drop_select should be called by the KV cache consumer " "(e.g. the decode vLLM instance)" diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index e27c6b2101b8..1fe7a90e9a71 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -12,7 +12,6 @@ """ from abc import ABC, abstractmethod -from typing import Optional import torch @@ -24,7 +23,7 @@ class KVPipeBase(ABC): """ @abstractmethod - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: torch.Tensor | None) -> None: """Send a tensor, or None, via the pipe. Need to support sending None -- important for error handling. @@ -42,7 +41,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: raise NotImplementedError @abstractmethod - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self) -> torch.Tensor | None: """Receive a tensor (can be None) from the pipeline. 
Returns: diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 65858f86aa23..542dde09abad 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -6,7 +6,6 @@ import struct from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import Optional, Union import torch import zmq @@ -16,7 +15,7 @@ from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger -from vllm.utils import join_host_port, make_zmq_path, split_host_port +from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port logger = init_logger(__name__) NONE_INT = -150886311 @@ -26,7 +25,7 @@ class MooncakeTransferEngineConfig: prefill_url: str decode_url: str - metadata_backend: Union[str, None] + metadata_backend: str | None metadata_server: str protocol: str device_name: str @@ -143,7 +142,7 @@ def initialize( metadata_server: str, protocol: str, device_name: str, - metadata_backend: Union[str, None], + metadata_backend: str | None, ) -> None: """Initialize the mooncake instance.""" if metadata_backend is None: @@ -231,19 +230,20 @@ class MooncakePipe(KVPipeBase): """MooncakeTransferEngine based Pipe implementation.""" def __init__( - self, local_rank: int, config: KVTransferConfig, device: Optional[str] = None + self, local_rank: int, config: KVTransferConfig, device: str | None = None ): """Initialize the mooncake pipe and set related parameters.""" self.config = config self.local_rank = local_rank self.kv_rank = self.config.kv_rank + assert self.kv_rank is not None if device is None: self.device = self._select_device(self.config.kv_buffer_device) else: self.device = self._select_device(device) self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank) - self.transport_thread: Optional[ThreadPoolExecutor] = None + self.transport_thread: ThreadPoolExecutor | None = None self.none_tensor = torch.tensor([NONE_INT], device=self.device) def _select_device(self, device: str) -> torch.device: @@ -267,7 +267,7 @@ def _recv_impl(self) -> torch.Tensor: data = self.transfer_engine.recv_bytes() return safetensors_load(data)["tensor"].to(self.device) - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: torch.Tensor | None) -> None: """Send tensor to the target process.""" if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -275,7 +275,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: assert len(tensor.shape) > 0 self.transport_thread.submit(self._send_impl, tensor) - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self) -> torch.Tensor | None: """Receive tensor from other processes.""" if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index c79b7e7e5030..526c5cd1d527 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -15,8 +15,8 @@ import threading import time +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor -from typing import Callable, Optional import torch @@ -35,7 +35,7 @@ def __init__(self, message): 
super().__init__(self.message) -Metadata = dict[str, Optional[torch.Tensor]] +Metadata = dict[str, torch.Tensor | None] class PyNcclPipe(KVPipeBase): @@ -47,12 +47,13 @@ def __init__( self, local_rank: int, config: KVTransferConfig, - device: Optional[str] = None, + device: str | None = None, port_offset: int = 0, ): self.config = config self.local_rank = local_rank self.kv_rank = self.config.kv_rank + assert self.kv_rank is not None self.kv_parallel_size = self.config.kv_parallel_size if device is None: self.device = self._select_device(self.config.kv_buffer_device) @@ -77,7 +78,7 @@ def __init__( self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size # transportation-related variables - self.transport_thread: Optional[ThreadPoolExecutor] = None + self.transport_thread: ThreadPoolExecutor | None = None self.buffer_size = 0 self.buffer_size_lock = threading.Lock() self.buffer_size_thresh = self.config.kv_buffer_size @@ -115,7 +116,7 @@ def _select_device(self, device: str): else: return torch.device("cpu") - def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: + def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata: """ Create the metadata as a dictionary based on the input tensor. @@ -167,7 +168,7 @@ def _recv_metadata(self) -> Metadata: """ return self.group.recv_obj(self.target_rank_for_recv) - def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + def _send_impl(self, tensor: torch.Tensor | None) -> None: """ The actual implementation of sending the tensor and its metadata to the target rank. @@ -181,7 +182,7 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: if tensor is not None: self.device_send_func(tensor.to(self.device), self.target_rank_for_send) - def _recv_impl(self) -> Optional[torch.Tensor]: + def _recv_impl(self) -> torch.Tensor | None: """ The actual implementation of receiving a tensor and its metadata from the target rank. @@ -198,7 +199,7 @@ def _recv_impl(self) -> Optional[torch.Tensor]: return buffer def send_tensor_wrapper( - self, tensor: Optional[torch.Tensor], tensor_size: int + self, tensor: torch.Tensor | None, tensor_size: int ) -> None: """ Wrapper for _send_impl to handle exceptions and update buffer size. @@ -228,7 +229,7 @@ def block_if_full(self): logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: torch.Tensor | None) -> None: """ Sends a tensor and its metadata to the destination rank in a non-blocking way. @@ -251,7 +252,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size) - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self) -> torch.Tensor | None: """ Receives a tensor and its metadata from the source rank. Blocking call. 
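Both pipes above lazily create a single-worker executor so that send_tensor can return immediately while submissions still execute in FIFO order. A rough sketch of that pattern, with a stand-in payload type rather than real tensors:

import time
from concurrent.futures import ThreadPoolExecutor

class AsyncSender:
    def __init__(self) -> None:
        self._executor: ThreadPoolExecutor | None = None

    def _send_impl(self, payload: bytes) -> None:
        time.sleep(0.01)  # stand-in for a blocking network send

    def send(self, payload: bytes) -> None:
        # Lazily create the worker; max_workers=1 preserves submission order.
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=1)
        self._executor.submit(self._send_impl, payload)

sender = AsyncSender()
sender.send(b"tensor-bytes")  # returns immediately; the send runs in the background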
diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index f8f65f28ff6d..cabfc10e7f94 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from vllm import envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType @@ -13,7 +13,7 @@ if TYPE_CHECKING: from vllm.config import VllmConfig -_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None +_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None def get_kv_transfer_group() -> KVConnectorBaseType: @@ -27,7 +27,7 @@ def has_kv_transfer_group() -> bool: return _KV_CONNECTOR_AGENT is not None -def is_v1_kv_transfer_group(connector: Optional[KVConnectorBaseType] = None) -> bool: +def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> bool: """Check if the KV connector is the v1 connector. If the argument is None, it will check the global KV connector diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index aee5507ade46..a9b01e82562b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -28,15 +28,18 @@ import pickle import weakref from collections import namedtuple +from collections.abc import Callable from contextlib import contextmanager, nullcontext from dataclasses import dataclass from datetime import timedelta from multiprocessing import shared_memory -from typing import Any, Callable, Optional, Union +from typing import Any, Optional from unittest.mock import patch import torch import torch.distributed +import torch.distributed._functional_collectives as funcol +import torch.distributed._symmetric_memory from torch.distributed import Backend, ProcessGroup from typing_extensions import deprecated @@ -46,10 +49,10 @@ ) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import ( +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.network_utils import get_distributed_init_method +from vllm.utils.torch_utils import ( direct_register_custom_op, - get_distributed_init_method, - resolve_obj_by_qualname, supports_custom_op, ) @@ -63,7 +66,7 @@ class GraphCaptureContext: def _split_tensor_dict( - tensor_dict: dict[str, Union[torch.Tensor, Any]], + tensor_dict: dict[str, torch.Tensor | Any], ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. 
If the value is a tensor, it is replaced @@ -159,6 +162,90 @@ def all_gather_fake( return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) +def patched_fused_scaled_matmul_reduce_scatter_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + # Copied from + # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189 + if A_scale.numel() > 1: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.flatten(0, -2).contiguous() + elif A_scale.numel() != 1: + raise ValueError( + "Invalid A_scale shape " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + + C = torch._scaled_mm( + A.flatten(0, -2).contiguous(), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + C = C.view(*output_shape[:-1], B.shape[1]) + res = funcol.reduce_scatter_tensor( + C, + reduce_op, + orig_scatter_dim, # need original scatter dim for 3D+ output tensor here + group_name, + ) + res = funcol.wait_tensor(res) + return res + + +def patched_fused_scaled_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( + A, + B, + A_scale, + B_scale, + reduce_op, + orig_scatter_dim, + scatter_dim_after_maybe_reshape, + group_name, + output_shape, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + + if supports_custom_op(): direct_register_custom_op( op_name="all_reduce", @@ -178,6 +265,15 @@ def all_gather_fake( fake_impl=all_gather_fake, ) + # TODO: Remove this once the pytorch fix + # (https://github.com/pytorch/pytorch/pull/165086) gets released, + # in either 2.9.1 or 2.10 + direct_register_custom_op( + op_name="patched_fused_scaled_matmul_reduce_scatter", + op_func=patched_fused_scaled_matmul_reduce_scatter, + fake_impl=patched_fused_scaled_matmul_reduce_scatter_fake, + ) + class GroupCoordinator: """ @@ -205,17 +301,17 @@ class GroupCoordinator: cpu_group: ProcessGroup # group for CPU communication device_group: ProcessGroup # group for device communication # device communicator (if use_device_communicator=True) - device_communicator: Optional[DeviceCommunicatorBase] - mq_broadcaster: Optional[Any] # shared memory broadcaster + device_communicator: DeviceCommunicatorBase | None + mq_broadcaster: Any | None # shared memory broadcaster def __init__( self, group_ranks: list[list[int]], local_rank: int, - torch_distributed_backend: Union[str, Backend], + torch_distributed_backend: str | Backend, use_device_communicator: bool, # whether to use device communicator use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, + group_name: str | 
None = None, ): group_name = group_name or "anonymous" self.unique_name = _get_unique_name(group_name) @@ -273,7 +369,7 @@ def __init__( from vllm.distributed.device_communicators.shm_broadcast import MessageQueue - self.mq_broadcaster: Optional[MessageQueue] = None + self.mq_broadcaster: MessageQueue | None = None if use_message_queue_broadcaster and self.world_size > 1: self.mq_broadcaster = MessageQueue.create_from_process_group( self.cpu_group, 1 << 22, 6 @@ -324,9 +420,7 @@ def prev_rank(self): return self.ranks[(rank_in_group - 1) % world_size] @contextmanager - def graph_capture( - self, graph_capture_context: Optional[GraphCaptureContext] = None - ): + def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None): if graph_capture_context is None: stream = torch.cuda.Stream() graph_capture_context = GraphCaptureContext(stream) @@ -407,9 +501,9 @@ def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: def all_gatherv( self, - input_: Union[torch.Tensor, list[torch.Tensor]], + input_: torch.Tensor | list[torch.Tensor], dim: int = 0, - sizes: Optional[list[int]] = None, + sizes: list[int] | None = None, ): if self.device_communicator is None: raise ValueError("No device communicator found") @@ -432,7 +526,7 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return self._reduce_scatter_out_place(input_, dim) def reduce_scatterv( - self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None ) -> torch.Tensor: if self.device_communicator is None: raise ValueError("No device communicator found") @@ -445,7 +539,7 @@ def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Ten def gather( self, input_: torch.Tensor, dst: int = 0, dim: int = -1 - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ NOTE: We assume that the input tensor is on the same device across all the ranks. @@ -474,7 +568,7 @@ def broadcast(self, input_: torch.Tensor, src: int = 0): ) return input_ - def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + def broadcast_object(self, obj: Any | None = None, src: int = 0): """Broadcast the input object. NOTE: `src` is the local rank of the source rank. """ @@ -499,7 +593,7 @@ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): return recv[0] def broadcast_object_list( - self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None + self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None ): """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. @@ -580,11 +674,11 @@ def recv_object(self, src: int) -> Any: def broadcast_tensor_dict( self, - tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, + tensor_dict: dict[str, torch.Tensor | Any] | None = None, src: int = 0, - group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None, - ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + group: ProcessGroup | None = None, + metadata_group: ProcessGroup | None = None, + ) -> dict[str, torch.Tensor | Any] | None: """Broadcast the input tensor dictionary. NOTE: `src` is the local rank of the source rank. 
""" @@ -662,11 +756,11 @@ def broadcast_tensor_dict( def send_tensor_dict( self, - tensor_dict: dict[str, Union[torch.Tensor, Any]], - dst: Optional[int] = None, + tensor_dict: dict[str, torch.Tensor | Any], + dst: int | None = None, all_gather_group: Optional["GroupCoordinator"] = None, - all_gather_tensors: Optional[dict[str, bool]] = None, - ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + all_gather_tensors: dict[str, bool] | None = None, + ) -> dict[str, torch.Tensor | Any] | None: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. @@ -750,10 +844,10 @@ def send_tensor_dict( def recv_tensor_dict( self, - src: Optional[int] = None, + src: int | None = None, all_gather_group: Optional["GroupCoordinator"] = None, - all_gather_tensors: Optional[dict[str, bool]] = None, - ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + all_gather_tensors: dict[str, bool] | None = None, + ) -> dict[str, torch.Tensor | Any] | None: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. @@ -848,7 +942,7 @@ def barrier(self): """ torch.distributed.barrier(group=self.cpu_group) - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: """Sends a tensor to the destination rank in a blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" if self.device_communicator is None: @@ -856,7 +950,7 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: self.device_communicator.send(tensor, dst) def recv( - self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + self, size: torch.Size, dtype: torch.dtype, src: int | None = None ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" @@ -902,8 +996,8 @@ def combine( return hidden_states -_WORLD: Optional[GroupCoordinator] = None -_NODE_COUNT: Optional[int] = None +_WORLD: GroupCoordinator | None = None +_NODE_COUNT: int | None = None def get_world_group() -> GroupCoordinator: @@ -928,7 +1022,7 @@ def init_model_parallel_group( local_rank: int, backend: str, use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, + group_name: str | None = None, ) -> GroupCoordinator: return GroupCoordinator( group_ranks=group_ranks, @@ -940,7 +1034,7 @@ def init_model_parallel_group( ) -_TP: Optional[GroupCoordinator] = None +_TP: GroupCoordinator | None = None def get_tp_group() -> GroupCoordinator: @@ -957,7 +1051,7 @@ def get_tensor_model_parallel_group(): return get_tp_group() -_DCP: Optional[GroupCoordinator] = None +_DCP: GroupCoordinator | None = None def get_dcp_group() -> GroupCoordinator: @@ -968,9 +1062,9 @@ def get_dcp_group() -> GroupCoordinator: # kept for backward compatibility get_context_model_parallel_group = get_dcp_group -_PP: Optional[GroupCoordinator] = None +_PP: GroupCoordinator | None = None -_DP: Optional[GroupCoordinator] = None +_DP: GroupCoordinator | None = None def get_dp_group() -> GroupCoordinator: @@ -978,7 +1072,7 @@ def get_dp_group() -> GroupCoordinator: return _DP -_EP: Optional[GroupCoordinator] = None +_EP: GroupCoordinator | None = None def get_ep_group() -> GroupCoordinator: @@ -1036,7 +1130,7 @@ def init_distributed_environment( distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", - timeout: Optional[timedelta] = None, + timeout: timedelta | None = None, ): logger.debug( "world_size=%d rank=%d 
local_rank=%d distributed_init_method=%s backend=%s", @@ -1063,7 +1157,7 @@ def init_distributed_environment( ip = parallel_config.data_parallel_master_ip port = parallel_config.get_next_dp_init_port() distributed_init_method = get_distributed_init_method(ip, port) - logger.info( + logger.debug( "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", world_size, rank, @@ -1113,8 +1207,8 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, - decode_context_model_parallel_size: Optional[int] = 1, - backend: Optional[str] = None, + decode_context_model_parallel_size: int | None = 1, + backend: str | None = None, ) -> None: """ Initialize model parallel groups. @@ -1228,7 +1322,7 @@ def initialize_model_parallel( group_ranks, get_world_group().local_rank, backend, group_name="ep" ) - logger.info( + logger.info_once( "rank %s in world size %s is assigned as " "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, @@ -1243,8 +1337,8 @@ def initialize_model_parallel( def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, - decode_context_model_parallel_size: Optional[int] = 1, - backend: Optional[str] = None, + decode_context_model_parallel_size: int | None = 1, + backend: str | None = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, or ensure tensor-parallel and pipeline-parallel sizes are equal to expected @@ -1409,7 +1503,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): def in_the_same_node_as( - pg: Union[ProcessGroup, StatelessProcessGroup], source_rank: int = 0 + pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0 ) -> list[bool]: """ This is a collective operation that returns if each rank is in the same node @@ -1432,7 +1526,9 @@ def in_the_same_node_as( ranks = list(range(world_size)) # local tensor in each process to store the result - is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + is_in_the_same_node = torch.tensor( + [0] * world_size, dtype=torch.int32, device="cpu" + ) magic_message = b"magic_message" shm = None @@ -1529,7 +1625,30 @@ def is_global_first_rank() -> bool: return True -def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int: +def is_local_first_rank() -> bool: + """ + Check if the current process is the first local rank (rank 0 on its node). + """ + try: + # prefer the initialized world group if available + global _WORLD + if _WORLD is not None: + return _WORLD.local_rank == 0 + + if not torch.distributed.is_initialized(): + return True + + # fallback to environment-provided local rank if available + # note: envs.LOCAL_RANK is set when using env:// launchers (e.g., torchrun) + try: + return int(envs.LOCAL_RANK) == 0 # type: ignore[arg-type] + except Exception: + return torch.distributed.get_rank() == 0 + except Exception: + return True + + +def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int: """ Returns the total number of nodes in the process group. 
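A hypothetical usage of the new `is_local_first_rank()` helper defined above: guard an action that should run once per node (the function and path below are made up for illustration).

from vllm.distributed.parallel_state import is_local_first_rank

def warm_local_cache(path: str) -> None:
    # Only one process per node should touch the shared local disk.
    if is_local_first_rank():
        print(f"warming cache at {path}")

warm_local_cache("/tmp/model-cache")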
diff --git a/vllm/distributed/tpu_distributed_utils.py b/vllm/distributed/tpu_distributed_utils.py index 3db25d1a1964..4ff1f0ce4410 100644 --- a/vllm/distributed/tpu_distributed_utils.py +++ b/vllm/distributed/tpu_distributed_utils.py @@ -30,9 +30,9 @@ def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None): self.q_weight: Parameter self.k_weight: Parameter self.v_weight: Parameter - self.q_bias: Optional[Parameter] - self.k_bias: Optional[Parameter] - self.v_bias: Optional[Parameter] + self.q_bias: Parameter | None + self.k_bias: Parameter | None + self.v_bias: Parameter | None self._load_weights_from_qkv_linear(qkv_linear) if mesh is not None: self._shard_weight(mesh) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index a35f28c25385..debf69c49b7d 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -15,7 +15,7 @@ from collections import deque from collections.abc import Sequence from datetime import timedelta -from typing import Any, Optional +from typing import Any import torch from torch.distributed import ProcessGroup, TCPStore @@ -29,7 +29,8 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import get_tcp_uri, is_torch_equal_or_newer +from vllm.utils.network_utils import get_tcp_uri +from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) @@ -150,7 +151,7 @@ class StatelessProcessGroup: store: torch._C._distributed_c10d.Store # stores a reference to the socket so that the file descriptor stays alive - socket: Optional[socket.socket] + socket: socket.socket | None data_expiration_seconds: int = 3600 # 1 hour @@ -197,7 +198,7 @@ def recv_obj(self, src: int) -> Any: self.recv_src_counter[src] += 1 return obj - def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: + def broadcast_obj(self, obj: Any | None, src: int) -> Any: """Broadcast an object from a source rank to all other ranks. It does not clean up after all ranks have received the object. Use it for limited times, e.g., for initialization. @@ -276,7 +277,7 @@ def barrier(self, timeout: float = 30.0): # Check for timeout cur_time = time.time() if cur_time - start_time > timeout: - raise RuntimeError("Barrier timed out after %f seconds", timeout) + raise RuntimeError(f"Barrier timed out after {timeout:.2f} seconds") # Check for each process for i in range(self.world_size): @@ -323,7 +324,9 @@ def barrier(self, timeout: float = 30.0): while len(processes_departed) < self.world_size: # Check for timeout if time.time() - start_time > timeout: - raise RuntimeError("Barrier departure timed out after %f s", timeout) + raise RuntimeError( + f"Barrier departure timed out after {timeout:.2f} seconds" + ) # Check for each process for i in range(self.world_size): @@ -415,7 +418,6 @@ def create( def init_gloo_process_group( - backend: Backend, prefix_store: PrefixStore, group_rank: int, group_size: int, @@ -432,7 +434,7 @@ def init_gloo_process_group( group_size, ) else: - options = ProcessGroup.Options(backend=backend) + options = ProcessGroup.Options(backend="gloo") pg = ProcessGroup( prefix_store, group_rank, @@ -504,24 +506,25 @@ def stateless_init_torch_distributed_process_group( # Use a PrefixStore to avoid accidental overrides of keys used by # different systems (e.g. RPC) in case the store is multi-tenant. 
prefix_store = PrefixStore(init_method, store) + try: + from vllm.platforms import current_platform - if backend == "gloo": - return init_gloo_process_group( + return current_platform.stateless_init_device_torch_dist_pg( backend=backend, prefix_store=prefix_store, group_rank=group_rank, group_size=group_size, timeout=timeout, ) - from vllm.platforms import current_platform - - return current_platform.stateless_init_device_torch_dist_pg( - backend=backend, - prefix_store=prefix_store, - group_rank=group_rank, - group_size=group_size, - timeout=timeout, - ) + except NotImplementedError: + # If platform doesn't implement stateless_init_device_torch_dist_pg, it + # will raise a NotImplementedError. In this case, we fall back to gloo. + return init_gloo_process_group( + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout, + ) def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e01f2d32d914..c0ea84b6e4e8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,15 +7,16 @@ import functools import json import sys +from collections.abc import Callable from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations +from types import UnionType from typing import ( TYPE_CHECKING, Annotated, Any, - Callable, Literal, - Optional, + TypeAlias, TypeVar, Union, cast, @@ -27,47 +28,47 @@ import regex as re import torch from pydantic import TypeAdapter, ValidationError +from pydantic.fields import FieldInfo from typing_extensions import TypeIs, deprecated import vllm.envs as envs +from vllm.attention.backends.registry import _Backend from vllm.config import ( - BlockSize, CacheConfig, - CacheDType, CompilationConfig, ConfigType, - ConvertOption, - DetailedTraceModules, - Device, DeviceConfig, - DistributedExecutorBackend, EPLBConfig, - HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, - LogprobsMode, LoRAConfig, - MambaDType, - MMEncoderTPMode, ModelConfig, - ModelDType, + MultiModalConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, - PrefixCachingHashAlgo, - RunnerOption, SchedulerConfig, - SchedulerPolicy, SpeculativeConfig, StructuredOutputsConfig, - TaskOption, - TokenizerMode, VllmConfig, get_attr_docs, ) -from vllm.config.multimodal import MMCacheType, MultiModalConfig -from vllm.config.parallel import ExpertPlacementStrategy +from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo +from vllm.config.device import Device +from vllm.config.model import ( + ConvertOption, + HfOverrides, + LogprobsMode, + ModelDType, + RunnerOption, + TaskOption, + TokenizerMode, +) +from vllm.config.multimodal import MMCacheType, MMEncoderTPMode +from vllm.config.observability import DetailedTraceModules +from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy +from vllm.config.scheduler import SchedulerPolicy from vllm.config.utils import get_field from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform @@ -81,16 +82,18 @@ maybe_override_with_speculators, ) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor +from vllm.utils import FlexibleArgumentParser, is_in_ray_actor +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.network_utils import get_ip from vllm.v1.sample.logits_processor import LogitsProcessor if 
TYPE_CHECKING: - from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.model_loader import LoadFormats from vllm.usage.usage_lib import UsageContext + from vllm.v1.executor import Executor else: - ExecutorBase = Any + Executor = Any QuantizationMethods = Any LoadFormats = Any UsageContext = Any @@ -99,8 +102,8 @@ # object is used to allow for special typing forms T = TypeVar("T") -TypeHint = Union[type[Any], object] -TypeHintT = Union[type[T], object] +TypeHint: TypeAlias = type[Any] | object +TypeHintT: TypeAlias = type[T] | object def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: @@ -115,8 +118,8 @@ def _parse_type(val: str) -> T: return _parse_type -def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: - def _optional_type(val: str) -> Optional[T]: +def optional_type(return_type: Callable[[str], T]) -> Callable[[str], T | None]: + def _optional_type(val: str) -> T | None: if val == "" or val == "None": return None return parse_type(return_type)(val) @@ -124,7 +127,7 @@ def _optional_type(val: str) -> Optional[T]: return _optional_type -def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: +def union_dict_and_str(val: str) -> str | dict[str, str] | None: if not re.match(r"(?s)^\s*{.*}\s*$", val): return str(val) return optional_type(json.loads)(val) @@ -162,6 +165,31 @@ def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]: return {"type": option_type, kwarg: sorted(options)} +def collection_to_kwargs(type_hints: set[TypeHint], type: TypeHint) -> dict[str, Any]: + type_hint = get_type(type_hints, type) + types = get_args(type_hint) + elem_type = types[0] + + # Handle Ellipsis + assert all(t is elem_type for t in types if t is not Ellipsis), ( + f"All non-Ellipsis elements must be of the same type. Got {types}." + ) + + # Handle Union types + if get_origin(elem_type) in {Union, UnionType}: + # Union for Union[X, Y] and UnionType for X | Y + assert str in get_args(elem_type), ( + "If element can have multiple types, one must be 'str' " + f"(i.e. 'list[int | str]'). Got {elem_type}." 
+ ) + elem_type = str + + return { + "type": elem_type, + "nargs": "+" if type is not tuple or Ellipsis in types else len(types), + } + + def is_not_builtin(type_hint: TypeHint) -> bool: """Check if the class is not a built-in type.""" return type_hint.__module__ != "builtins" @@ -175,7 +203,8 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: if origin is Annotated: type_hints.update(get_type_hints(args[0])) - elif origin is Union: + elif origin in {Union, UnionType}: + # Union for Union[X, Y] and UnionType for X | Y for arg in args: type_hints.update(get_type_hints(arg)) else: @@ -196,7 +225,7 @@ def is_online_quantization(quantization: Any) -> bool: @functools.lru_cache(maxsize=30) -def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: +def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: # Save time only getting attr docs if we're generating help text cls_docs = get_attr_docs(cls) if NEEDS_HELP else {} kwargs = {} @@ -211,6 +240,13 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the default value of the field if field.default is not MISSING: default = field.default + # Handle pydantic.Field defaults + if isinstance(default, FieldInfo): + default = ( + default.default + if default.default_factory is None + else default.default_factory() + ) elif field.default_factory is not MISSING: default = field.default_factory() @@ -243,25 +279,11 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: elif contains_type(type_hints, Literal): kwargs[name].update(literal_to_kwargs(type_hints)) elif contains_type(type_hints, tuple): - type_hint = get_type(type_hints, tuple) - types = get_args(type_hint) - tuple_type = types[0] - assert all(t is tuple_type for t in types if t is not Ellipsis), ( - "All non-Ellipsis tuple elements must be of the same " - f"type. Got {types}." - ) - kwargs[name]["type"] = tuple_type - kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) + kwargs[name].update(collection_to_kwargs(type_hints, tuple)) elif contains_type(type_hints, list): - type_hint = get_type(type_hints, list) - types = get_args(type_hint) - list_type = types[0] - if get_origin(list_type) is Union: - msg = "List type must contain str if it is a Union." - assert str in get_args(list_type), msg - list_type = str - kwargs[name]["type"] = list_type - kwargs[name]["nargs"] = "+" + kwargs[name].update(collection_to_kwargs(type_hints, list)) + elif contains_type(type_hints, set): + kwargs[name].update(collection_to_kwargs(type_hints, set)) elif contains_type(type_hints, int): kwargs[name]["type"] = int # Special case for large integers @@ -304,7 +326,7 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: return kwargs -def get_kwargs(cls: ConfigType) -> dict[str, Any]: +def get_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: """Return argparse kwargs for the given Config dataclass. 
If `--help` or `mkdocs` are not present in the command line command, the @@ -322,46 +344,53 @@ class EngineArgs: """Arguments for vLLM engine.""" model: str = ModelConfig.model - served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name - tokenizer: Optional[str] = ModelConfig.tokenizer - hf_config_path: Optional[str] = ModelConfig.hf_config_path + served_model_name: str | list[str] | None = ModelConfig.served_model_name + tokenizer: str | None = ModelConfig.tokenizer + hf_config_path: str | None = ModelConfig.hf_config_path runner: RunnerOption = ModelConfig.runner convert: ConvertOption = ModelConfig.convert - task: Optional[TaskOption] = ModelConfig.task + task: TaskOption | None = ModelConfig.task skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode trust_remote_code: bool = ModelConfig.trust_remote_code allowed_local_media_path: str = ModelConfig.allowed_local_media_path - allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains - download_dir: Optional[str] = LoadConfig.download_dir + allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains + download_dir: str | None = LoadConfig.download_dir safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy - load_format: Union[str, LoadFormats] = LoadConfig.load_format + load_format: str | LoadFormats = LoadConfig.load_format config_format: str = ModelConfig.config_format dtype: ModelDType = ModelConfig.dtype kv_cache_dtype: CacheDType = CacheConfig.cache_dtype - seed: Optional[int] = ModelConfig.seed - max_model_len: Optional[int] = ModelConfig.max_model_len - cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes") + seed: int | None = ModelConfig.seed + max_model_len: int | None = ModelConfig.max_model_len + cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes + cudagraph_capture_sizes: list[int] | None = ( + CompilationConfig.cudagraph_capture_sizes + ) + max_cudagraph_capture_size: int | None = get_field( + CompilationConfig, "max_cudagraph_capture_size" + ) # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. 
- distributed_executor_backend: Optional[ - Union[str, DistributedExecutorBackend, type[ExecutorBase]] - ] = ParallelConfig.distributed_executor_backend + distributed_executor_backend: ( + str | DistributedExecutorBackend | type[Executor] | None + ) = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size - data_parallel_rank: Optional[int] = None - data_parallel_start_rank: Optional[int] = None - data_parallel_size_local: Optional[int] = None - data_parallel_address: Optional[str] = None - data_parallel_rpc_port: Optional[int] = None + data_parallel_rank: int | None = None + data_parallel_start_rank: int | None = None + data_parallel_size_local: int | None = None + data_parallel_address: str | None = None + data_parallel_rpc_port: int | None = None data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + all2all_backend: str | None = ParallelConfig.all2all_backend enable_dbo: bool = ParallelConfig.enable_dbo dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold @@ -379,11 +408,11 @@ class EngineArgs: eplb_window_size: int = EPLBConfig.window_size eplb_step_interval: int = EPLBConfig.step_interval eplb_log_balancedness: bool = EPLBConfig.log_balancedness - max_parallel_loading_workers: Optional[int] = ( + max_parallel_loading_workers: int | None = ( ParallelConfig.max_parallel_loading_workers ) - block_size: Optional[BlockSize] = CacheConfig.block_size - enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching + block_size: BlockSize | None = CacheConfig.block_size + enable_prefix_caching: bool | None = CacheConfig.enable_prefix_caching prefix_caching_hash_algo: PrefixCachingHashAlgo = ( CacheConfig.prefix_caching_hash_algo ) @@ -392,63 +421,67 @@ class EngineArgs: swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization - kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes - max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens + kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes + max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold - max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs + max_num_seqs: int | None = SchedulerConfig.max_num_seqs max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode disable_log_stats: bool = False - revision: Optional[str] = ModelConfig.revision - code_revision: Optional[str] = ModelConfig.code_revision + aggregate_engine_logging: bool = False + revision: str | None = ModelConfig.revision + code_revision: str | None = ModelConfig.code_revision rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling") - rope_theta: Optional[float] 
= ModelConfig.rope_theta - hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token + rope_theta: float | None = ModelConfig.rope_theta + hf_token: bool | str | None = ModelConfig.hf_token hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides") - tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision - quantization: Optional[QuantizationMethods] = ModelConfig.quantization + tokenizer_revision: str | None = ModelConfig.tokenizer_revision + quantization: QuantizationMethods | None = ModelConfig.quantization enforce_eager: bool = ModelConfig.enforce_eager disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field( + limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field( MultiModalConfig, "limit_per_prompt" ) + enable_mm_embeds: bool = MultiModalConfig.enable_mm_embeds interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings media_io_kwargs: dict[str, dict[str, Any]] = get_field( MultiModalConfig, "media_io_kwargs" ) - mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs + mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb - mm_processor_cache_type: Optional[MMCacheType] = ( + mm_processor_cache_type: MMCacheType | None = ( MultiModalConfig.mm_processor_cache_type ) mm_shm_cache_max_object_size_mb: int = ( MultiModalConfig.mm_shm_cache_max_object_size_mb ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode - io_processor_plugin: Optional[str] = None + mm_encoder_attn_backend: _Backend | str | None = ( + MultiModalConfig.mm_encoder_attn_backend + ) + io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling video_pruning_rate: float = MultiModalConfig.video_pruning_rate # LoRA fields enable_lora: bool = False - enable_lora_bias: bool = LoRAConfig.bias_enabled max_loras: int = LoRAConfig.max_loras max_lora_rank: int = LoRAConfig.max_lora_rank - default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras + default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras - max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras - lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype + max_cpu_loras: int | None = LoRAConfig.max_cpu_loras + lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight - num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override + num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") - ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns + ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns") - enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill + enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input disable_hybrid_kv_cache_manager: bool = ( @@ -459,36 +492,37 @@ class EngineArgs: VllmConfig, 
"structured_outputs_config" ) reasoning_parser: str = StructuredOutputsConfig.reasoning_parser + # Deprecated guided decoding fields - guided_decoding_backend: Optional[str] = None - guided_decoding_disable_fallback: Optional[bool] = None - guided_decoding_disable_any_whitespace: Optional[bool] = None - guided_decoding_disable_additional_properties: Optional[bool] = None + guided_decoding_backend: str | None = None + guided_decoding_disable_fallback: bool | None = None + guided_decoding_disable_any_whitespace: bool | None = None + guided_decoding_disable_additional_properties: bool | None = None - logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern + logits_processor_pattern: str | None = ModelConfig.logits_processor_pattern - speculative_config: Optional[dict[str, Any]] = None + speculative_config: dict[str, Any] | None = None - show_hidden_metrics_for_version: Optional[str] = ( + show_hidden_metrics_for_version: str | None = ( ObservabilityConfig.show_hidden_metrics_for_version ) - otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint - collect_detailed_traces: Optional[list[DetailedTraceModules]] = ( + otlp_traces_endpoint: str | None = ObservabilityConfig.otlp_traces_endpoint + collect_detailed_traces: list[DetailedTraceModules] | None = ( ObservabilityConfig.collect_detailed_traces ) scheduling_policy: SchedulerPolicy = SchedulerConfig.policy - scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls + scheduler_cls: str | type[object] = SchedulerConfig.scheduler_cls - pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config - override_pooler_config: Optional[Union[dict, PoolerConfig]] = ( + pooler_config: PoolerConfig | None = ModelConfig.pooler_config + override_pooler_config: dict | PoolerConfig | None = ( ModelConfig.override_pooler_config ) compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls - kv_transfer_config: Optional[KVTransferConfig] = None - kv_events_config: Optional[KVEventsConfig] = None + kv_transfer_config: KVTransferConfig | None = None + kv_events_config: KVEventsConfig | None = None generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode @@ -510,7 +544,7 @@ class EngineArgs: # DEPRECATED enable_multimodal_encoder_data_parallel: bool = False - logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = ( + logits_processors: list[str | type[LogitsProcessor]] | None = ( ModelConfig.logits_processors ) """Custom logitproc types""" @@ -754,6 +788,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"] ) + parallel_group.add_argument( + "--all2all-backend", **parallel_kwargs["all2all_backend"] + ) parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) parallel_group.add_argument( "--dbo-decode-token-threshold", @@ -866,6 +903,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] ) + multimodal_group.add_argument( + "--enable-mm-embeds", **multimodal_kwargs["enable_mm_embeds"] + ) multimodal_group.add_argument( "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"] ) @@ -888,6 +928,10 @@ def add_cli_args(parser: 
FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"] ) + multimodal_group.add_argument( + "--mm-encoder-attn-backend", + **multimodal_kwargs["mm_encoder_attn_backend"], + ) multimodal_group.add_argument( "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"] ) @@ -910,7 +954,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action=argparse.BooleanOptionalAction, help="If True, enable handling of LoRA adapters.", ) - lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"]) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) lora_group.add_argument( @@ -970,9 +1013,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--max-long-partial-prefills", **scheduler_kwargs["max_long_partial_prefills"], ) - scheduler_group.add_argument( - "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"] - ) scheduler_group.add_argument( "--long-prefill-token-threshold", **scheduler_kwargs["long_prefill_token_threshold"], @@ -1002,6 +1042,29 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--async-scheduling", **scheduler_kwargs["async_scheduling"] ) + # Compilation arguments + compilation_kwargs = get_kwargs(CompilationConfig) + compilation_group = parser.add_argument_group( + title="CompilationConfig", + description=CompilationConfig.__doc__, + ) + compilation_group.add_argument( + "--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"] + ) + compilation_kwargs["cudagraph_capture_sizes"]["help"] = ( + "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0," + " whichever is soonest. Please use --cudagraph-capture-sizes instead." 
+ ) + compilation_group.add_argument( + "--cuda-graph-sizes", + **compilation_kwargs["cudagraph_capture_sizes"], + deprecated=True, + ) + compilation_group.add_argument( + "--max-cudagraph-capture-size", + **compilation_kwargs["max_cudagraph_capture_size"], + ) + # vLLM arguments vllm_kwargs = get_kwargs(VllmConfig) vllm_group = parser.add_argument_group( @@ -1036,6 +1099,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Disable logging statistics.", ) + parser.add_argument( + "--aggregate-engine-logging", + action="store_true", + help="Log aggregate rather than per-engine statistics " + "when using data parallelism.", + ) return parser @classmethod @@ -1120,6 +1189,7 @@ def create_model_config(self) -> ModelConfig: enable_prompt_embeds=self.enable_prompt_embeds, served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, + enable_mm_embeds=self.enable_mm_embeds, interleave_mm_strings=self.interleave_mm_strings, media_io_kwargs=self.media_io_kwargs, skip_mm_profiling=self.skip_mm_profiling, @@ -1129,6 +1199,7 @@ def create_model_config(self) -> ModelConfig: mm_processor_cache_type=self.mm_processor_cache_type, mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, + mm_encoder_attn_backend=self.mm_encoder_attn_backend, pooler_config=self.pooler_config, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, @@ -1183,7 +1254,7 @@ def create_speculative_config( target_parallel_config: ParallelConfig, enable_chunked_prefill: bool, disable_log_stats: bool, - ) -> Optional["SpeculativeConfig"]: + ) -> SpeculativeConfig | None: """Initializes and returns a SpeculativeConfig object based on `speculative_config`. @@ -1210,7 +1281,7 @@ def create_speculative_config( def create_engine_config( self, - usage_context: Optional[UsageContext] = None, + usage_context: UsageContext | None = None, headless: bool = False, ) -> VllmConfig: """ @@ -1263,7 +1334,8 @@ def create_engine_config( # Set default arguments for V1 Engine. self._set_default_args(usage_context, model_config) - # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 + # Disable chunked prefill and prefix caching for: + # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( CpuArchEnum.POWERPC, CpuArchEnum.S390X, @@ -1276,9 +1348,16 @@ def create_engine_config( "disabling it for V1 backend." ) self.enable_chunked_prefill = False + logger.info( + "Prefix caching is not supported for ARM and POWER, " + "S390X and RISC-V CPUs; " + "disabling it for V1 backend." + ) + self.enable_prefix_caching = False + assert self.enable_chunked_prefill is not None - sliding_window: Optional[int] = None + sliding_window: int | None = None if not is_interleaved(model_config.hf_text_config): # Only set CacheConfig.sliding_window if the model is all sliding # window. 
Otherwise CacheConfig.sliding_window will override the @@ -1321,7 +1400,13 @@ def create_engine_config( import ray ray_runtime_env = ray.get_runtime_context().runtime_env - logger.info("Using ray runtime env: %s", ray_runtime_env) + # Avoid logging sensitive environment variables + sanitized_env = ray_runtime_env.to_dict() if ray_runtime_env else {} + if "env_vars" in sanitized_env: + sanitized_env["env_vars"] = { + k: "***" for k in sanitized_env["env_vars"] + } + logger.info("Using ray runtime env (env vars redacted): %s", sanitized_env) # Get the current placement group if Ray is initialized and # we are in a Ray actor. If so, then the placement group will be @@ -1369,8 +1454,15 @@ def create_engine_config( "data_parallel_size_local must be set to use data_parallel_hybrid_lb." ) - # Local DP size defaults to global DP size if not set. - data_parallel_size_local = self.data_parallel_size + if self.data_parallel_backend == "ray" and ( + envs.VLLM_RAY_DP_PACK_STRATEGY == "span" + ): + # Local DP size defaults to 1 if DP ranks span + # multiple nodes + data_parallel_size_local = 1 + else: + # Otherwise local DP size defaults to global DP size if not set + data_parallel_size_local = self.data_parallel_size # DP address, used in multi-node case for torch distributed group # and ZMQ sockets. @@ -1399,13 +1491,6 @@ def create_engine_config( ) if self.async_scheduling: - # Async scheduling does not work with the uniprocess backend. - if self.distributed_executor_backend is None: - self.distributed_executor_backend = "mp" - logger.info( - "Defaulting to mp-based distributed executor " - "backend for async scheduling." - ) if self.pipeline_parallel_size > 1: raise ValueError( "Async scheduling is not supported with pipeline-parallel-size > 1." @@ -1441,6 +1526,7 @@ def create_engine_config( data_parallel_backend=self.data_parallel_backend, data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, + all2all_backend=self.all2all_backend, enable_dbo=self.enable_dbo, dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, @@ -1461,6 +1547,15 @@ def create_engine_config( _api_process_rank=self._api_process_rank, ) + if self.async_scheduling and ( + parallel_config.distributed_executor_backend not in ("mp", "uni") + ): + raise ValueError( + "Currently, async scheduling only supports the `mp` or `uni` " + "distributed executor backends, but you chose " + f"`{parallel_config.distributed_executor_backend}`."
+ ) + speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -1479,13 +1574,11 @@ def create_engine_config( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, - cuda_graph_sizes=self.cuda_graph_sizes, num_lookahead_slots=num_lookahead_slots, enable_chunked_prefill=self.enable_chunked_prefill, disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, is_encoder_decoder=model_config.is_encoder_decoder, - send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, scheduler_cls=self.scheduler_cls, max_num_partial_prefills=self.max_num_partial_prefills, @@ -1503,7 +1596,6 @@ def create_engine_config( lora_config = ( LoRAConfig( - bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, default_mm_loras=self.default_mm_loras, @@ -1533,15 +1625,13 @@ def create_engine_config( if self.guided_decoding_backend is not None: so_config.guided_decoding_backend = self.guided_decoding_backend if self.guided_decoding_disable_fallback is not None: - so_config.guided_decoding_disable_fallback = ( - self.guided_decoding_disable_fallback - ) + so_config.disable_fallback = self.guided_decoding_disable_fallback if self.guided_decoding_disable_any_whitespace is not None: - so_config.guided_decoding_disable_any_whitespace = ( + so_config.disable_any_whitespace = ( self.guided_decoding_disable_any_whitespace ) if self.guided_decoding_disable_additional_properties is not None: - so_config.guided_decoding_disable_additional_properties = ( + so_config.disable_additional_properties = ( self.guided_decoding_disable_additional_properties ) @@ -1551,6 +1641,38 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) + # Compilation config overrides + if self.cuda_graph_sizes is not None: + logger.warning( + "--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or " + "v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes " + "instead." + ) + if self.compilation_config.cudagraph_capture_sizes is not None: + raise ValueError( + "cuda_graph_sizes and compilation_config." + "cudagraph_capture_sizes are mutually exclusive" + ) + self.compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes + if self.cudagraph_capture_sizes is not None: + if self.compilation_config.cudagraph_capture_sizes is not None: + raise ValueError( + "cudagraph_capture_sizes and compilation_config." + "cudagraph_capture_sizes are mutually exclusive" + ) + self.compilation_config.cudagraph_capture_sizes = ( + self.cudagraph_capture_sizes + ) + if self.max_cudagraph_capture_size is not None: + if self.compilation_config.max_cudagraph_capture_size is not None: + raise ValueError( + "max_cudagraph_capture_size and compilation_config." + "max_cudagraph_capture_size are mutually exclusive" + ) + self.compilation_config.max_cudagraph_capture_size = ( + self.max_cudagraph_capture_size + ) + config = VllmConfig( model_config=model_config, cache_config=cache_config, @@ -1582,13 +1704,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ) return False - # No Mamba or Encoder-Decoder so far. - if not model_config.is_v1_compatible: - _raise_or_fallback( - feature_name=model_config.architectures, recommend_to_remove=False - ) - return False - # No Concurrent Partial Prefills so far. 
if ( self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills @@ -1677,22 +1792,12 @@ def _set_default_args( ) -> None: """Set Default Arguments for V1 Engine.""" - # V1 always uses chunked prefills and prefix caching + # V1 uses chunked prefills and prefix caching by default # for non-pooling tasks. # For pooling tasks the default is False if model_config.runner_type != "pooling": self.enable_chunked_prefill = True - # TODO: When prefix caching supports prompt embeds inputs, this - # check can be removed. - if self.enable_prompt_embeds and self.enable_prefix_caching is not False: - logger.warning( - "--enable-prompt-embeds and --enable-prefix-caching " - "are not supported together in V1. Prefix caching has " - "been disabled." - ) - self.enable_prefix_caching = False - if self.enable_prefix_caching is None: # Disable prefix caching default for hybrid models # since the feature is still experimental. @@ -1718,11 +1823,6 @@ def _set_default_args( self.enable_prefix_caching = incremental_prefill_supported logger.info("(%s) prefix caching by default", action) - # V1 should use the new scheduler by default. - # Swap it only if this arg is set to the original V0 default - if self.scheduler_cls == EngineArgs.scheduler_cls: - self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" - # When no user override, set the default values based on the usage # context. # Use different default values for different hardware. diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py deleted file mode 100644 index 45b798ed96cb..000000000000 --- a/vllm/engine/metrics.py +++ /dev/null @@ -1,688 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import time -from collections import Counter as CollectionsCounter -from typing import Optional, Union, cast - -import numpy as np -import prometheus_client - -from vllm.config import SupportsMetricsInfo, VllmConfig -from vllm.engine.metrics_types import StatLoggerBase, Stats -from vllm.executor.ray_utils import ray -from vllm.logger import init_logger - -if ray is not None: - from ray.util import metrics as ray_metrics -else: - ray_metrics = None - -logger = init_logger(__name__) - -prometheus_client.disable_created_metrics() - -# The begin-* and end* here are used by the documentation generator -# to extract the metrics definitions. - - -# --8<-- [start:metrics-definitions] -class Metrics: - """ - vLLM uses a multiprocessing-based frontend for the OpenAI server. - This means that we need to run prometheus_client in multiprocessing mode - See https://prometheus.github.io/client_python/multiprocess/ for more - details on limitations. 
- """ - - labelname_finish_reason = "finished_reason" - labelname_waiting_lora_adapters = "waiting_lora_adapters" - labelname_running_lora_adapters = "running_lora_adapters" - labelname_max_lora = "max_lora" - _gauge_cls = prometheus_client.Gauge - _counter_cls = prometheus_client.Counter - _histogram_cls = prometheus_client.Histogram - - def __init__(self, labelnames: list[str], vllm_config: VllmConfig): - # Unregister any existing vLLM collectors (for CI/CD) - self._unregister_vllm_metrics() - - max_model_len = vllm_config.model_config.max_model_len - - # Use this flag to hide metrics that were deprecated in - # a previous release and which will be removed future - self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics - - # System stats - # Scheduler State - self.gauge_scheduler_running = self._gauge_cls( - name="vllm:num_requests_running", - documentation="Number of requests currently running on GPU.", - labelnames=labelnames, - multiprocess_mode="sum", - ) - self.gauge_scheduler_waiting = self._gauge_cls( - name="vllm:num_requests_waiting", - documentation="Number of requests waiting to be processed.", - labelnames=labelnames, - multiprocess_mode="sum", - ) - self.gauge_lora_info = self._gauge_cls( - name="vllm:lora_requests_info", - documentation="Running stats on lora requests.", - labelnames=[ - self.labelname_running_lora_adapters, - self.labelname_max_lora, - self.labelname_waiting_lora_adapters, - ], - multiprocess_mode="livemostrecent", - ) - - # KV Cache Usage in % - self.gauge_gpu_cache_usage = self._gauge_cls( - name="vllm:gpu_cache_usage_perc", - documentation="GPU KV-cache usage. 1 means 100 percent usage.", - labelnames=labelnames, - multiprocess_mode="sum", - ) - - # Iteration stats - self.counter_num_preemption = self._counter_cls( - name="vllm:num_preemptions_total", - documentation="Cumulative number of preemption from the engine.", - labelnames=labelnames, - ) - self.counter_prompt_tokens = self._counter_cls( - name="vllm:prompt_tokens_total", - documentation="Number of prefill tokens processed.", - labelnames=labelnames, - ) - self.counter_generation_tokens = self._counter_cls( - name="vllm:generation_tokens_total", - documentation="Number of generation tokens processed.", - labelnames=labelnames, - ) - self.histogram_iteration_tokens = self._histogram_cls( - name="vllm:iteration_tokens_total", - documentation="Histogram of number of tokens per engine_step.", - labelnames=labelnames, - buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], - ) - self.histogram_time_to_first_token = self._histogram_cls( - name="vllm:time_to_first_token_seconds", - documentation="Histogram of time to first token in seconds.", - labelnames=labelnames, - buckets=[ - 0.001, - 0.005, - 0.01, - 0.02, - 0.04, - 0.06, - 0.08, - 0.1, - 0.25, - 0.5, - 0.75, - 1.0, - 2.5, - 5.0, - 7.5, - 10.0, - 20.0, - 40.0, - 80.0, - 160.0, - 640.0, - 2560.0, - ], - ) - # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds - # TODO: in 0.12, only enable if show_hidden_metrics=True - self.histogram_time_per_output_token = self._histogram_cls( - name="vllm:time_per_output_token_seconds", - documentation=( - "Histogram of time per output token in seconds." - "DEPRECATED: Use vllm:inter_token_latency_seconds instead." 
- ), - labelnames=labelnames, - buckets=[ - 0.01, - 0.025, - 0.05, - 0.075, - 0.1, - 0.15, - 0.2, - 0.3, - 0.4, - 0.5, - 0.75, - 1.0, - 2.5, - 5.0, - 7.5, - 10.0, - 20.0, - 40.0, - 80.0, - ], - ) - self.histogram_inter_token_latency = self._histogram_cls( - name="vllm:inter_token_latency_seconds", - documentation="Histogram of inter token latency in seconds.", - labelnames=labelnames, - buckets=[ - 0.01, - 0.025, - 0.05, - 0.075, - 0.1, - 0.15, - 0.2, - 0.3, - 0.4, - 0.5, - 0.75, - 1.0, - 2.5, - 5.0, - 7.5, - 10.0, - 20.0, - 40.0, - 80.0, - ], - ) - - # Request stats - # Latency - request_latency_buckets = [ - 0.3, - 0.5, - 0.8, - 1.0, - 1.5, - 2.0, - 2.5, - 5.0, - 10.0, - 15.0, - 20.0, - 30.0, - 40.0, - 50.0, - 60.0, - 120.0, - 240.0, - 480.0, - 960.0, - 1920.0, - 7680.0, - ] - self.histogram_e2e_time_request = self._histogram_cls( - name="vllm:e2e_request_latency_seconds", - documentation="Histogram of end to end request latency in seconds.", - labelnames=labelnames, - buckets=request_latency_buckets, - ) - self.histogram_queue_time_request = self._histogram_cls( - name="vllm:request_queue_time_seconds", - documentation="Histogram of time spent in WAITING phase for request.", - labelnames=labelnames, - buckets=request_latency_buckets, - ) - self.histogram_inference_time_request = self._histogram_cls( - name="vllm:request_inference_time_seconds", - documentation="Histogram of time spent in RUNNING phase for request.", - labelnames=labelnames, - buckets=request_latency_buckets, - ) - self.histogram_prefill_time_request = self._histogram_cls( - name="vllm:request_prefill_time_seconds", - documentation="Histogram of time spent in PREFILL phase for request.", - labelnames=labelnames, - buckets=request_latency_buckets, - ) - self.histogram_decode_time_request = self._histogram_cls( - name="vllm:request_decode_time_seconds", - documentation="Histogram of time spent in DECODE phase for request.", - labelnames=labelnames, - buckets=request_latency_buckets, - ) - - # Metadata - self.histogram_num_prompt_tokens_request = self._histogram_cls( - name="vllm:request_prompt_tokens", - documentation="Number of prefill tokens processed.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) - self.histogram_num_generation_tokens_request = self._histogram_cls( - name="vllm:request_generation_tokens", - documentation="Number of generation tokens processed.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) - self.histogram_max_num_generation_tokens_request = self._histogram_cls( - name="vllm:request_max_num_generation_tokens", - documentation="Histogram of maximum number of requested generation tokens.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) - self.histogram_n_request = self._histogram_cls( - name="vllm:request_params_n", - documentation="Histogram of the n request parameter.", - labelnames=labelnames, - buckets=[1, 2, 5, 10, 20], - ) - self.histogram_max_tokens_request = self._histogram_cls( - name="vllm:request_params_max_tokens", - documentation="Histogram of the max_tokens request parameter.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) - self.counter_request_success = self._counter_cls( - name="vllm:request_success_total", - documentation="Count of successfully processed requests.", - labelnames=labelnames + [Metrics.labelname_finish_reason], - ) - - # --8<-- [end:metrics-definitions] - - def _unregister_vllm_metrics(self) -> None: - for collector in 
list(prometheus_client.REGISTRY._collector_to_names): - if hasattr(collector, "_name") and "vllm" in collector._name: - prometheus_client.REGISTRY.unregister(collector) - - -class _RayGaugeWrapper: - """Wraps around ray.util.metrics.Gauge to provide same API as - prometheus_client.Gauge""" - - def __init__( - self, - name: str, - documentation: str = "", - labelnames: Optional[list[str]] = None, - multiprocess_mode: str = "", - ): - del multiprocess_mode - labelnames_tuple = tuple(labelnames) if labelnames else None - self._gauge = ray_metrics.Gauge( - name=name, description=documentation, tag_keys=labelnames_tuple - ) - - def labels(self, **labels): - self._gauge.set_default_tags(labels) - return self - - def set(self, value: Union[int, float]): - return self._gauge.set(value) - - def set_to_current_time(self): - # ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html - return self._gauge.set(time.time()) - - -class _RayCounterWrapper: - """Wraps around ray.util.metrics.Counter to provide same API as - prometheus_client.Counter""" - - def __init__( - self, name: str, documentation: str = "", labelnames: Optional[list[str]] = None - ): - labelnames_tuple = tuple(labelnames) if labelnames else None - self._counter = ray_metrics.Counter( - name=name, description=documentation, tag_keys=labelnames_tuple - ) - - def labels(self, **labels): - self._counter.set_default_tags(labels) - return self - - def inc(self, value: Union[int, float] = 1.0): - if value == 0: - return - return self._counter.inc(value) - - -class _RayHistogramWrapper: - """Wraps around ray.util.metrics.Histogram to provide same API as - prometheus_client.Histogram""" - - def __init__( - self, - name: str, - documentation: str = "", - labelnames: Optional[list[str]] = None, - buckets: Optional[list[float]] = None, - ): - labelnames_tuple = tuple(labelnames) if labelnames else None - boundaries = buckets if buckets else [] - self._histogram = ray_metrics.Histogram( - name=name, - description=documentation, - tag_keys=labelnames_tuple, - boundaries=boundaries, - ) - - def labels(self, **labels): - self._histogram.set_default_tags(labels) - return self - - def observe(self, value: Union[int, float]): - return self._histogram.observe(value) - - -class RayMetrics(Metrics): - """ - RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. - Provides the same metrics as Metrics but uses Ray's util.metrics library. - """ - - _gauge_cls: type[prometheus_client.Gauge] = cast( - type[prometheus_client.Gauge], _RayGaugeWrapper - ) - _counter_cls: type[prometheus_client.Counter] = cast( - type[prometheus_client.Counter], _RayCounterWrapper - ) - _histogram_cls: type[prometheus_client.Histogram] = cast( - type[prometheus_client.Histogram], _RayHistogramWrapper - ) - - def __init__(self, labelnames: list[str], vllm_config: VllmConfig): - if ray_metrics is None: - raise ImportError("RayMetrics requires Ray to be installed.") - super().__init__(labelnames, vllm_config) - - def _unregister_vllm_metrics(self) -> None: - # No-op on purpose - pass - - -def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: - """ - Builds a list of buckets with increasing powers of 10 multiplied by - mantissa values until the value exceeds the specified maximum. 
- - """ - exponent = 0 - buckets: list[int] = [] - while True: - for m in mantissa_lst: - value = m * 10**exponent - if value <= max_value: - buckets.append(value) - else: - return buckets - exponent += 1 - - -def build_1_2_5_buckets(max_value: int) -> list[int]: - """ - Example: - >>> build_1_2_5_buckets(100) - [1, 2, 5, 10, 20, 50, 100] - """ - return build_buckets([1, 2, 5], max_value) - - -def build_1_2_3_5_8_buckets(max_value: int) -> list[int]: - """ - Example: - >>> build_1_2_3_5_8_buckets(100) - [1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100] - """ - return build_buckets([1, 2, 3, 5, 8], max_value) - - -def local_interval_elapsed(now: float, last_log: float, local_interval: float) -> bool: - elapsed_time = now - last_log - return elapsed_time > local_interval - - -def get_throughput(tracked_stats: list[int], now: float, last_log: float) -> float: - return float(np.sum(tracked_stats) / (now - last_log)) - - -class LoggingStatLogger(StatLoggerBase): - """LoggingStatLogger is used in LLMEngine to log to Stdout.""" - - def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: - super().__init__(local_interval, vllm_config) - self.last_prompt_throughput: Optional[float] = None - self.last_generation_throughput: Optional[float] = None - - def log(self, stats: Stats) -> None: - """Called by LLMEngine. - Logs to Stdout every self.local_interval seconds.""" - - # Save tracked stats for token counters. - self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) - self.num_generation_tokens.append(stats.num_generation_tokens_iter) - - # Log locally every local_interval seconds. - if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): - # Compute summary metrics for tracked stats (and log them - # to prometheus if applicable). - prompt_throughput = get_throughput( - self.num_prompt_tokens, now=stats.now, last_log=self.last_local_log - ) - generation_throughput = get_throughput( - self.num_generation_tokens, now=stats.now, last_log=self.last_local_log - ) - - log_fn = logger.info - if not any( - ( - prompt_throughput, - generation_throughput, - self.last_prompt_throughput, - self.last_generation_throughput, - ) - ): - # Avoid log noise on an idle production system - log_fn = logger.debug - - log_fn( - "Avg prompt throughput: %.1f tokens/s, " - "Avg generation throughput: %.1f tokens/s, " - "Running: %d reqs, Swapped: %d reqs, " - "Pending: %d reqs, GPU KV cache usage: %.1f%%, " - "CPU KV cache usage: %.1f%%.", - prompt_throughput, - generation_throughput, - stats.num_running_sys, - stats.num_swapped_sys, - stats.num_waiting_sys, - stats.gpu_cache_usage_sys * 100, - stats.cpu_cache_usage_sys * 100, - ) - if ( - stats.cpu_prefix_cache_hit_rate >= 0 - or stats.gpu_prefix_cache_hit_rate >= 0 - ): - log_fn( - "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", - stats.gpu_prefix_cache_hit_rate * 100, - stats.cpu_prefix_cache_hit_rate * 100, - ) - - self._reset(stats, prompt_throughput, generation_throughput) - - def _reset(self, stats, prompt_throughput, generation_throughput) -> None: - # Reset tracked stats for next interval. 
- self.num_prompt_tokens = [] - self.num_generation_tokens = [] - self.last_local_log = stats.now - self.last_prompt_throughput = prompt_throughput - self.last_generation_throughput = generation_throughput - - def info(self, type: str, obj: SupportsMetricsInfo) -> None: - raise NotImplementedError - - -class PrometheusStatLogger(StatLoggerBase): - """PrometheusStatLogger is used LLMEngine to log to Prometheus.""" - - _metrics_cls = Metrics - _gauge_cls = prometheus_client.Gauge - - def __init__( - self, local_interval: float, labels: dict[str, str], vllm_config: VllmConfig - ) -> None: - super().__init__(local_interval, vllm_config) - # Prometheus metrics - self.labels = labels - self.metrics = self._metrics_cls( - labelnames=list(labels.keys()), vllm_config=vllm_config - ) - - def _log_gauge(self, gauge, data: Union[int, float]) -> None: - # Convenience function for logging to gauge. - gauge.labels(**self.labels).set(data) - - def _log_counter(self, counter, data: Union[int, float]) -> None: - # Convenience function for logging to counter. - # Prevent ValueError from negative increment - if data < 0: - logger.warning("Skipping negative increment of %g to %s", data, counter) - return - counter.labels(**self.labels).inc(data) - - def _log_counter_labels( - self, counter, data: CollectionsCounter, label_key: str - ) -> None: - # Convenience function for collection counter of labels. - for label, count in data.items(): - counter.labels(**{**self.labels, label_key: label}).inc(count) - - def _log_histogram(self, histogram, data: Union[list[int], list[float]]) -> None: - # Convenience function for logging list to histogram. - for datum in data: - histogram.labels(**self.labels).observe(datum) - - def _log_gauge_string(self, gauge, data: dict[str, str]) -> None: - gauge.labels(**data).set_to_current_time() - - def _log_prometheus(self, stats: Stats) -> None: - # System state data - self._log_gauge(self.metrics.gauge_scheduler_running, stats.num_running_sys) - self._log_gauge(self.metrics.gauge_scheduler_waiting, stats.num_waiting_sys) - self._log_gauge(self.metrics.gauge_gpu_cache_usage, stats.gpu_cache_usage_sys) - # Including max-lora in metric, in future this property of lora - # config maybe extended to be dynamic. 
- lora_info = { - self.metrics.labelname_running_lora_adapters: ",".join( - stats.running_lora_adapters - ), - self.metrics.labelname_waiting_lora_adapters: ",".join( - stats.waiting_lora_adapters - ), - self.metrics.labelname_max_lora: stats.max_lora, - } - self._log_gauge_string(self.metrics.gauge_lora_info, lora_info) - # Iteration level data - self._log_counter( - self.metrics.counter_num_preemption, stats.num_preemption_iter - ) - self._log_counter( - self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter - ) - self._log_counter( - self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter - ) - self._log_histogram( - self.metrics.histogram_iteration_tokens, [stats.num_tokens_iter] - ) - self._log_histogram( - self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter - ) - self._log_histogram( - self.metrics.histogram_time_per_output_token, - stats.inter_token_latencies_iter, - ) - self._log_histogram( - self.metrics.histogram_inter_token_latency, stats.inter_token_latencies_iter - ) - - # Request level data - # Latency - self._log_histogram( - self.metrics.histogram_e2e_time_request, stats.time_e2e_requests - ) - self._log_histogram( - self.metrics.histogram_queue_time_request, stats.time_queue_requests - ) - self._log_histogram( - self.metrics.histogram_inference_time_request, stats.time_inference_requests - ) - self._log_histogram( - self.metrics.histogram_prefill_time_request, stats.time_prefill_requests - ) - self._log_histogram( - self.metrics.histogram_decode_time_request, stats.time_decode_requests - ) - # Metadata - finished_reason_counter = CollectionsCounter(stats.finished_reason_requests) - self._log_counter_labels( - self.metrics.counter_request_success, - finished_reason_counter, - Metrics.labelname_finish_reason, - ) - self._log_histogram( - self.metrics.histogram_num_prompt_tokens_request, - stats.num_prompt_tokens_requests, - ) - self._log_histogram( - self.metrics.histogram_num_generation_tokens_request, - stats.num_generation_tokens_requests, - ) - self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) - self._log_histogram( - self.metrics.histogram_max_num_generation_tokens_request, - stats.max_num_generation_tokens_requests, - ) - self._log_histogram( - self.metrics.histogram_max_tokens_request, stats.max_tokens_requests - ) - - def log(self, stats: Stats): - """Logs to prometheus and tracked stats every iteration.""" - # Log to prometheus. - self._log_prometheus(stats) - - # Save tracked stats for token counters. - self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) - self.num_generation_tokens.append(stats.num_generation_tokens_iter) - - # Log locally every local_interval seconds. - if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): - # Reset tracked stats for next interval. - self.num_prompt_tokens = [] - self.num_generation_tokens = [] - self.last_local_log = stats.now - - def info(self, type: str, obj: SupportsMetricsInfo) -> None: - # Info type metrics are syntactic sugar for a gauge permanently set to 1 - # Since prometheus multiprocessing mode does not support Info, emulate - # info here with a gauge. 
- if type == "cache_config": - metrics_info = obj.metrics_info() - info_gauge = self._gauge_cls( - name="vllm:cache_config_info", - documentation="Information of the LLMEngine CacheConfig", - labelnames=metrics_info.keys(), - multiprocess_mode="mostrecent", - ) - info_gauge.labels(**metrics_info).set(1) - - -class RayPrometheusStatLogger(PrometheusStatLogger): - """RayPrometheusStatLogger uses Ray metrics instead.""" - - _metrics_cls = RayMetrics - - def info(self, type: str, obj: SupportsMetricsInfo) -> None: - return None diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py deleted file mode 100644 index ac796f4e1c75..000000000000 --- a/vllm/engine/metrics_types.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -These types are defined in this file to avoid importing vllm.engine.metrics -and therefore importing prometheus_client. - -This is required due to usage of Prometheus multiprocess mode to enable -metrics after splitting out the uvicorn process from the engine process. - -Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR -before prometheus_client is imported. Typically, this is done by setting -the env variable before launch, but since we are a library, we need to -do this in Python code and lazily import prometheus_client. -""" - -import time -from abc import ABC, abstractmethod -from dataclasses import dataclass - -from vllm.config import SupportsMetricsInfo, VllmConfig - - -@dataclass -class Stats: - """Created by LLMEngine for use by StatLogger.""" - - now: float - - # System stats (should have _sys suffix) - # Scheduler State - num_running_sys: int - num_waiting_sys: int - num_swapped_sys: int - # KV Cache Usage in % - gpu_cache_usage_sys: float - cpu_cache_usage_sys: float - # Prefix caching block hit rate - cpu_prefix_cache_hit_rate: float - gpu_prefix_cache_hit_rate: float - - # Iteration stats (should have _iter suffix) - num_prompt_tokens_iter: int - num_generation_tokens_iter: int - num_tokens_iter: int - time_to_first_tokens_iter: list[float] - inter_token_latencies_iter: list[float] - num_preemption_iter: int - - # Request stats (should have _requests suffix) - # Latency - time_e2e_requests: list[float] - time_queue_requests: list[float] - time_inference_requests: list[float] - time_prefill_requests: list[float] - time_decode_requests: list[float] - # Metadata - num_prompt_tokens_requests: list[int] - num_generation_tokens_requests: list[int] - n_requests: list[int] - max_num_generation_tokens_requests: list[int] - max_tokens_requests: list[int] - finished_reason_requests: list[str] - waiting_lora_adapters: list[str] - running_lora_adapters: list[str] - max_lora: str - - -class StatLoggerBase(ABC): - """Base class for StatLogger.""" - - def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: - # Tracked stats over current local logging interval. 
- self.num_prompt_tokens: list[int] = [] - self.num_generation_tokens: list[int] = [] - self.last_local_log = time.time() - self.local_interval = local_interval - - @abstractmethod - def log(self, stats: Stats) -> None: - raise NotImplementedError - - @abstractmethod - def info(self, type: str, obj: SupportsMetricsInfo) -> None: - raise NotImplementedError diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index feb2e841c83a..20b8eb57f743 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,26 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Iterable, Mapping -from typing import Any, Optional, Union +from typing import Any -from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import ModelConfig, VllmConfig -from vllm.inputs.data import PromptType, TokensPrompt -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt -from vllm.inputs.preprocess import InputPreprocessor +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput -from vllm.plugins.io_processors.interface import IOProcessor +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors import IOProcessor from vllm.pooling_params import PoolingParams -from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import Device, collect_from_async_generator, random_uuid +from vllm.utils import Device from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.processor import Processor logger = init_logger(__name__) @@ -28,6 +25,11 @@ class EngineClient(ABC): """Protocol class for Clients to Engine""" + vllm_config: VllmConfig + model_config: ModelConfig + processor: Processor + io_processor: IOProcessor | None + @property @abstractmethod def is_running(self) -> bool: ... @@ -47,210 +49,36 @@ def dead_error(self) -> BaseException: ... @abstractmethod def generate( self, - prompt: Union[EngineCoreRequest, PromptType], + prompt: EngineCoreRequest | PromptType, sampling_params: SamplingParams, request_id: str, *, - prompt_text: Optional[str] = None, - lora_request: Optional[LoRARequest] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - trace_headers: Optional[Mapping[str, str]] = None, + prompt_text: str | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - data_parallel_rank: Optional[int] = None, + data_parallel_rank: int | None = None, ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request.""" ... 
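The reworked `EngineClient.generate` above keeps its async-generator contract while accepting either an `EngineCoreRequest` or a plain prompt. A minimal consumption sketch, assuming `client` is any `EngineClient` implementation (the helper name `stream_text` is illustrative, not part of this change):

```python
from vllm.sampling_params import SamplingParams


async def stream_text(client, prompt: str, request_id: str) -> str:
    """Collect the final text of one streamed request."""
    final_text = ""
    async for output in client.generate(
        prompt,
        SamplingParams(max_tokens=64),
        request_id,
    ):
        # Each RequestOutput carries the cumulative generation so far.
        final_text = output.outputs[0].text
    return final_text
```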
- async def beam_search( - self, - prompt: PromptType, - request_id: str, - params: BeamSearchParams, - lora_request: Optional[LoRARequest] = None, - ) -> AsyncGenerator[RequestOutput, None]: - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - include_stop_str_in_output = params.include_stop_str_in_output - - preprocessor = await self.get_input_preprocessor() - tokenizer = preprocessor.get_tokenizer() - eos_token_id = tokenizer.eos_token_id - - if is_explicit_encoder_decoder_prompt(prompt): - raise NotImplementedError - else: - processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) - - if processed_inputs["type"] == "embeds": - raise NotImplementedError - - # This is a workaround to fix multimodal beam search; this is a - # bandaid fix for 2 small problems: - # 1. Multi_modal_data on the processed_inputs currently resolves to - # `None`. - # 2. preprocessing above expands the multimodal placeholders. However, - # this happens again in generation, so the double expansion causes - # a mismatch. - # TODO - would be ideal to handle this more gracefully. - if isinstance(prompt, str): - prompt_text = prompt - prompt_token_ids = [] - multi_modal_data = None - else: - prompt_text = prompt.get("prompt") - prompt_token_ids = prompt.get("prompt_token_ids", []) - multi_modal_data = prompt.get("multi_modal_data") - - mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") - - tokenized_length = len(prompt_token_ids) - - sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) - - beam_search_params = SamplingParams( - logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature, - ) - all_beams = [ - BeamSearchSequence( - tokens=prompt_token_ids, - cum_logprob=0, - logprobs=[], - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, - lora_request=lora_request, - ) - ] - completed = [] - - for _ in range(max_tokens): - prompts_batch, lora_req_batch = zip( - *[ - ( - TokensPrompt( - prompt_token_ids=beam.tokens, - multi_modal_data=beam.multi_modal_data, - mm_processor_kwargs=beam.mm_processor_kwargs, - ), - beam.lora_request, - ) - for beam in all_beams - ] - ) - - tasks = [] - - request_id = f"beam_search-{random_uuid()}" - for i, (individual_prompt, lora_req) in enumerate( - zip(prompts_batch, lora_req_batch) - ): - request_id_item = f"{request_id}-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.generate( - individual_prompt, - beam_search_params, - request_id_item, - lora_request=lora_req, - ) - ) - ) - tasks.append(task) - - output = await asyncio.gather(*tasks) - - output = [x[0] for x in output] - - new_beams = [] - for i, current_beam in enumerate(all_beams): - result = output[i] - - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - if token_id == eos_token_id and not ignore_eos: - completed.append( - BeamSearchSequence( - tokens=current_beam.tokens + [token_id] - if include_stop_str_in_output - else current_beam.tokens, - logprobs=current_beam.logprobs + [logprobs], - cum_logprob=current_beam.cum_logprob - + logprob_obj.logprob, - finish_reason="stop", - stop_reason=eos_token_id, - ) - ) - else: - new_beams.append( - BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - lora_request=current_beam.lora_request, - 
cum_logprob=current_beam.cum_logprob - + logprob_obj.logprob, - multi_modal_data=current_beam.multi_modal_data, - mm_processor_kwargs=current_beam.mm_processor_kwargs, - ) - ) - - sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) - all_beams = sorted_beams[:beam_width] - - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = sorted_completed[:beam_width] - - for beam in best_beams: - if beam.tokens[-1] == eos_token_id and not ignore_eos: - # Skip the eos token in the text. - tokens = beam.tokens[tokenized_length:-1] - else: - tokens = beam.tokens[tokenized_length:] - beam.text = tokenizer.decode(tokens) - - yield RequestOutput( - request_id=request_id, - prompt=prompt_text, - outputs=[ - CompletionOutput( - text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens[tokenized_length:], - index=i, - logprobs=beam.logprobs, - finish_reason=beam.finish_reason - if beam.finish_reason is not None - else "length", - stop_reason=beam.stop_reason, - ) - for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=None, - ) - @abstractmethod def encode( self, prompt: PromptType, pooling_params: PoolingParams, request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, + lora_request: LoRARequest | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from a pooling model.""" ... @abstractmethod - async def abort(self, request_id: Union[str, Iterable[str]]) -> None: + async def abort(self, request_id: str | Iterable[str]) -> None: """Abort a request. Args: @@ -259,29 +87,11 @@ async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """ ... - @abstractmethod - async def get_vllm_config(self) -> VllmConfig: - """Get the vllm configuration of the vLLM engine.""" - ... - - @abstractmethod - async def get_model_config(self) -> ModelConfig: - """Get the model configuration of the vLLM engine.""" - ... - - @abstractmethod - async def get_input_preprocessor(self) -> InputPreprocessor: - """Get the input processor of the vLLM engine.""" - ... - @abstractmethod async def get_tokenizer(self) -> AnyTokenizer: """Get the tokenizer""" ... - async def get_io_processor(self) -> IOProcessor: - raise NotImplementedError - @abstractmethod async def is_tracing_enabled(self) -> bool: ... @@ -300,7 +110,7 @@ async def start_profile(self) -> None: @abstractmethod async def stop_profile(self) -> None: - """Start profiling the engine""" + """Stop profiling the engine""" ... @abstractmethod @@ -309,7 +119,7 @@ async def reset_mm_cache(self) -> None: ... @abstractmethod - async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: + async def reset_prefix_cache(self, device: Device | None = None) -> None: """Reset the prefix cache""" ... @@ -319,7 +129,7 @@ async def sleep(self, level: int = 1) -> None: ... @abstractmethod - async def wake_up(self, tags: Optional[list[str]] = None) -> None: + async def wake_up(self, tags: list[str] | None = None) -> None: """Wake up the engine""" ... 
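With the async getters (`get_vllm_config`, `get_model_config`, `get_input_preprocessor`, `get_io_processor`) removed above, the configuration objects become plain attributes declared on `EngineClient`. A hypothetical migration sketch (the function name is illustrative only):

```python
def log_context_length(engine) -> None:
    # Before this change: model_config = await engine.get_model_config()
    # After: vllm_config / model_config are attributes on the EngineClient.
    model_config = engine.model_config
    print(f"max_model_len = {model_config.max_model_len}")
```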
@@ -342,9 +152,9 @@ async def scale_elastic_ep( async def collective_rpc( self, method: str, - timeout: Optional[float] = None, + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict] = None, + kwargs: dict | None = None, ): """Perform a collective RPC call to the given path.""" raise NotImplementedError diff --git a/vllm/entrypoints/anthropic/__init__.py b/vllm/entrypoints/anthropic/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/anthropic/api_server.py b/vllm/entrypoints/anthropic/api_server.py new file mode 100644 index 000000000000..249a7ee0121a --- /dev/null +++ b/vllm/entrypoints/anthropic/api_server.py @@ -0,0 +1,300 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from: +# https://github.com/vllm/vllm/entrypoints/openai/api_server.py + +import asyncio +import signal +import tempfile +from argparse import Namespace +from http import HTTPStatus + +import uvloop +from fastapi import APIRouter, Depends, FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State + +import vllm.envs as envs +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.anthropic.protocol import ( + AnthropicErrorResponse, + AnthropicMessagesRequest, + AnthropicMessagesResponse, +) +from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client, + create_server_socket, + lifespan, + load_log_config, + validate_api_server_args, + validate_json_request, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +from vllm.entrypoints.openai.protocol import ErrorResponse +from vllm.entrypoints.openai.serving_models import ( + BaseModelPath, + OpenAIServingModels, +) + +# +# yapf: enable +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import ( + cli_env_setup, + load_aware_call, + process_chat_template, + process_lora_modules, + with_cancellation, +) +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser, set_ulimit +from vllm.utils.network_utils import is_valid_ipv6_address +from vllm.version import __version__ as VLLM_VERSION + +prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger("vllm.entrypoints.anthropic.api_server") + +_running_tasks: set[asyncio.Task] = set() + +router = APIRouter() + + +def messages(request: Request) -> AnthropicServingMessages: + return request.app.state.anthropic_serving_messages + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health", response_class=Response) +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.get("/ping", response_class=Response) +@router.post("/ping", response_class=Response) +async def ping(raw_request: Request) -> Response: + """Ping check. 
Endpoint required for SageMaker""" + return await health(raw_request) + + +@router.post( + "/v1/messages", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def create_messages(request: AnthropicMessagesRequest, raw_request: Request): + handler = messages(raw_request) + if handler is None: + return messages(raw_request).create_error_response( + message="The model does not support Messages API" + ) + + generator = await handler.create_messages(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump()) + + elif isinstance(generator, AnthropicMessagesResponse): + logger.debug( + "Anthropic Messages Response: %s", generator.model_dump(exclude_none=True) + ) + return JSONResponse(content=generator.model_dump(exclude_none=True)) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +async def init_app_state( + engine_client: EngineClient, + state: State, + args: Namespace, +) -> None: + vllm_config = engine_client.vllm_config + + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) for name in served_model_names + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + state.vllm_config = vllm_config + model_config = vllm_config.model_config + + default_mm_loras = ( + vllm_config.lora_config.default_mm_loras + if vllm_config.lora_config is not None + else {} + ) + lora_modules = process_lora_modules(args.lora_modules, default_mm_loras) + + resolved_chat_template = await process_chat_template( + args.chat_template, engine_client, model_config + ) + + state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + base_model_paths=base_model_paths, + lora_modules=lora_modules, + ) + await state.openai_serving_models.init_static_loras() + state.anthropic_serving_messages = AnthropicServingMessages( + engine_client, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) + + +def setup_server(args): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + validate_api_server_args(args) + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + """Run a single-worker API server.""" + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +def build_app(args: Namespace) -> FastAPI: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + return app + + +async def run_server_worker( + listen_address, sock, args, client_config=None, **uvicorn_kwargs +) -> None: + """Run a single API server worker.""" + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + # Load logging config for uvicorn if specified + log_config = load_log_config(args.log_config_file) + if log_config is not None: + uvicorn_kwargs["log_config"] = log_config + + async with build_async_engine_client( + args, + client_config=client_config, + ) as engine_client: + app = build_app(args) + + await init_app_state(engine_client, app.state, args) + + logger.info("Starting vLLM API server %d on %s", server_index, listen_address) + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. + access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/entrypoints/cli/main.py for CLI + # entrypoints. + cli_env_setup() + parser = FlexibleArgumentParser( + description="vLLM Anthropic-Compatible RESTful API server." 
+ ) + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py new file mode 100644 index 000000000000..626ca7472ae6 --- /dev/null +++ b/vllm/entrypoints/anthropic/protocol.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pydantic models for Anthropic API protocol""" + +import time +from typing import Any, Literal, Optional + +from pydantic import BaseModel, field_validator + + +class AnthropicError(BaseModel): + """Error structure for Anthropic API""" + + type: str + message: str + + +class AnthropicErrorResponse(BaseModel): + """Error response structure for Anthropic API""" + + type: Literal["error"] = "error" + error: AnthropicError + + +class AnthropicUsage(BaseModel): + """Token usage information""" + + input_tokens: int + output_tokens: int + cache_creation_input_tokens: int | None = None + cache_read_input_tokens: int | None = None + + +class AnthropicContentBlock(BaseModel): + """Content block in message""" + + type: Literal["text", "image", "tool_use", "tool_result"] + text: str | None = None + # For image content + source: dict[str, Any] | None = None + # For tool use/result + id: str | None = None + name: str | None = None + input: dict[str, Any] | None = None + content: str | list[dict[str, Any]] | None = None + is_error: bool | None = None + + +class AnthropicMessage(BaseModel): + """Message structure""" + + role: Literal["user", "assistant"] + content: str | list[AnthropicContentBlock] + + +class AnthropicTool(BaseModel): + """Tool definition""" + + name: str + description: str | None = None + input_schema: dict[str, Any] + + @field_validator("input_schema") + @classmethod + def validate_input_schema(cls, v): + if not isinstance(v, dict): + raise ValueError("input_schema must be a dictionary") + if "type" not in v: + v["type"] = "object" # Default to object type + return v + + +class AnthropicToolChoice(BaseModel): + """Tool Choice definition""" + + type: Literal["auto", "any", "tool"] + name: str | None = None + + +class AnthropicMessagesRequest(BaseModel): + """Anthropic Messages API request""" + + model: str + messages: list[AnthropicMessage] + max_tokens: int + metadata: dict[str, Any] | None = None + stop_sequences: list[str] | None = None + stream: bool | None = False + system: str | list[AnthropicContentBlock] | None = None + temperature: float | None = None + tool_choice: AnthropicToolChoice | None = None + tools: list[AnthropicTool] | None = None + top_k: int | None = None + top_p: float | None = None + + @field_validator("model") + @classmethod + def validate_model(cls, v): + if not v: + raise ValueError("Model is required") + return v + + @field_validator("max_tokens") + @classmethod + def validate_max_tokens(cls, v): + if v <= 0: + raise ValueError("max_tokens must be positive") + return v + + +class AnthropicDelta(BaseModel): + """Delta for streaming responses""" + + type: Literal["text_delta", "input_json_delta"] | None = None + text: str | None = None + partial_json: str | None = None + + # Message delta + stop_reason: ( + Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None + ) = None + stop_sequence: str | None = None + + +class AnthropicStreamEvent(BaseModel): + """Streaming event""" + + type: Literal[ + "message_start", + "message_delta", + "message_stop", + "content_block_start", + 
"content_block_delta", + "content_block_stop", + "ping", + "error", + ] + message: Optional["AnthropicMessagesResponse"] = None + delta: AnthropicDelta | None = None + content_block: AnthropicContentBlock | None = None + index: int | None = None + error: AnthropicError | None = None + usage: AnthropicUsage | None = None + + +class AnthropicMessagesResponse(BaseModel): + """Anthropic Messages API response""" + + id: str + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + content: list[AnthropicContentBlock] + model: str + stop_reason: ( + Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"] | None + ) = None + stop_sequence: str | None = None + usage: AnthropicUsage | None = None + + def model_post_init(self, __context): + if not self.id: + self.id = f"msg_{int(time.time() * 1000)}" diff --git a/vllm/entrypoints/anthropic/serving_messages.py b/vllm/entrypoints/anthropic/serving_messages.py new file mode 100644 index 000000000000..11c96adf332f --- /dev/null +++ b/vllm/entrypoints/anthropic/serving_messages.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from +# https://github.com/vllm/vllm/entrypoints/openai/serving_chat.py + +"""Anthropic Messages API serving handler""" + +import json +import logging +import time +from collections.abc import AsyncGenerator +from typing import Any + +from fastapi import Request + +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.anthropic.protocol import ( + AnthropicContentBlock, + AnthropicDelta, + AnthropicError, + AnthropicMessagesRequest, + AnthropicMessagesResponse, + AnthropicStreamEvent, + AnthropicUsage, +) +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamResponse, + ChatCompletionToolsParam, + ErrorResponse, + StreamOptions, +) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_models import OpenAIServingModels + +logger = logging.getLogger(__name__) + + +def wrap_data_with_event(data: str, event: str): + return f"event: {event}\ndata: {data}\n\n" + + +class AnthropicServingMessages(OpenAIServingChat): + """Handler for Anthropic Messages API requests""" + + def __init__( + self, + engine_client: EngineClient, + models: OpenAIServingModels, + response_role: str, + *, + request_logger: RequestLogger | None, + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + reasoning_parser: str = "", + enable_auto_tools: bool = False, + tool_parser: str | None = None, + enable_prompt_tokens_details: bool = False, + enable_force_include_usage: bool = False, + ): + super().__init__( + engine_client=engine_client, + models=models, + response_role=response_role, + request_logger=request_logger, + chat_template=chat_template, + chat_template_content_format=chat_template_content_format, + return_tokens_as_token_ids=return_tokens_as_token_ids, + reasoning_parser=reasoning_parser, + enable_auto_tools=enable_auto_tools, + tool_parser=tool_parser, + enable_prompt_tokens_details=enable_prompt_tokens_details, + enable_force_include_usage=enable_force_include_usage, + ) + self.stop_reason_map = { + "stop": "end_turn", + "length": "max_tokens", + 
"tool_calls": "tool_use", + } + + def _convert_anthropic_to_openai_request( + self, anthropic_request: AnthropicMessagesRequest + ) -> ChatCompletionRequest: + """Convert Anthropic message format to OpenAI format""" + openai_messages = [] + + # Add system message if provided + if anthropic_request.system: + if isinstance(anthropic_request.system, str): + openai_messages.append( + {"role": "system", "content": anthropic_request.system} + ) + else: + system_prompt = "" + for block in anthropic_request.system: + if block.type == "text" and block.text: + system_prompt += block.text + openai_messages.append({"role": "system", "content": system_prompt}) + + for msg in anthropic_request.messages: + openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore + if isinstance(msg.content, str): + openai_msg["content"] = msg.content + else: + # Handle complex content blocks + content_parts: list[dict[str, Any]] = [] + tool_calls: list[dict[str, Any]] = [] + + for block in msg.content: + if block.type == "text" and block.text: + content_parts.append({"type": "text", "text": block.text}) + elif block.type == "image" and block.source: + content_parts.append( + { + "type": "image_url", + "image_url": {"url": block.source.get("data", "")}, + } + ) + elif block.type == "tool_use": + # Convert tool use to function call format + tool_call = { + "id": block.id or f"call_{int(time.time())}", + "type": "function", + "function": { + "name": block.name or "", + "arguments": json.dumps(block.input or {}), + }, + } + tool_calls.append(tool_call) + elif block.type == "tool_result": + if msg.role == "user": + openai_messages.append( + { + "role": "tool", + "tool_call_id": block.id or "", + "content": str(block.content) + if block.content + else "", + } + ) + else: + # Assistant tool result becomes regular text + tool_result_text = ( + str(block.content) if block.content else "" + ) + content_parts.append( + { + "type": "text", + "text": f"Tool result: {tool_result_text}", + } + ) + + # Add tool calls to the message if any + if tool_calls: + openai_msg["tool_calls"] = tool_calls # type: ignore + + # Add content parts if any + if content_parts: + if len(content_parts) == 1 and content_parts[0]["type"] == "text": + openai_msg["content"] = content_parts[0]["text"] + else: + openai_msg["content"] = content_parts # type: ignore + elif not tool_calls: + continue + + openai_messages.append(openai_msg) + + req = ChatCompletionRequest( + model=anthropic_request.model, + messages=openai_messages, + max_tokens=anthropic_request.max_tokens, + max_completion_tokens=anthropic_request.max_tokens, + stop=anthropic_request.stop_sequences, + temperature=anthropic_request.temperature, + top_p=anthropic_request.top_p, + top_k=anthropic_request.top_k, + ) + + if anthropic_request.stream: + req.stream = anthropic_request.stream + req.stream_options = StreamOptions.validate({"include_usage": True}) + + if anthropic_request.tool_choice is None: + req.tool_choice = None + elif anthropic_request.tool_choice.type == "auto": + req.tool_choice = "auto" + elif anthropic_request.tool_choice.type == "any": + req.tool_choice = "required" + elif anthropic_request.tool_choice.type == "tool": + req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate( + { + "type": "function", + "function": {"name": anthropic_request.tool_choice.name}, + } + ) + + tools = [] + if anthropic_request.tools is None: + return req + for tool in anthropic_request.tools: + tools.append( + ChatCompletionToolsParam.model_validate( + { + "type": "function", + 
"function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + }, + } + ) + ) + if req.tool_choice is None: + req.tool_choice = "auto" + req.tools = tools + return req + + async def create_messages( + self, + request: AnthropicMessagesRequest, + raw_request: Request | None = None, + ) -> AsyncGenerator[str, None] | AnthropicMessagesResponse | ErrorResponse: + """ + Messages API similar to Anthropic's API. + + See https://docs.anthropic.com/en/api/messages + for the API specification. This API mimics the Anthropic messages API. + """ + logger.debug("Received messages request %s", request.model_dump_json()) + chat_req = self._convert_anthropic_to_openai_request(request) + logger.debug("Convert to OpenAI request %s", request.model_dump_json()) + generator = await self.create_chat_completion(chat_req, raw_request) + + if isinstance(generator, ErrorResponse): + return generator + + elif isinstance(generator, ChatCompletionResponse): + return self.messages_full_converter(generator) + + return self.message_stream_converter(generator) + + def messages_full_converter( + self, + generator: ChatCompletionResponse, + ) -> AnthropicMessagesResponse: + result = AnthropicMessagesResponse( + id=generator.id, + content=[], + model=generator.model, + usage=AnthropicUsage( + input_tokens=generator.usage.prompt_tokens, + output_tokens=generator.usage.completion_tokens, + ), + ) + if generator.choices[0].finish_reason == "stop": + result.stop_reason = "end_turn" + elif generator.choices[0].finish_reason == "length": + result.stop_reason = "max_tokens" + elif generator.choices[0].finish_reason == "tool_calls": + result.stop_reason = "tool_use" + + content: list[AnthropicContentBlock] = [ + AnthropicContentBlock( + type="text", + text=generator.choices[0].message.content + if generator.choices[0].message.content + else "", + ) + ] + + for tool_call in generator.choices[0].message.tool_calls: + anthropic_tool_call = AnthropicContentBlock( + type="tool_use", + id=tool_call.id, + name=tool_call.function.name, + input=json.loads(tool_call.function.arguments), + ) + content += [anthropic_tool_call] + + result.content = content + + return result + + async def message_stream_converter( + self, + generator: AsyncGenerator[str, None], + ) -> AsyncGenerator[str, None]: + try: + first_item = True + finish_reason = None + content_block_index = 0 + content_block_started = False + + async for item in generator: + if item.startswith("data:"): + data_str = item[5:].strip().rstrip("\n") + if data_str == "[DONE]": + stop_message = AnthropicStreamEvent( + type="message_stop", + ) + data = stop_message.model_dump_json( + exclude_unset=True, exclude_none=True + ) + yield wrap_data_with_event(data, "message_stop") + yield "data: [DONE]\n\n" + else: + origin_chunk = ChatCompletionStreamResponse.model_validate_json( + data_str + ) + + if first_item: + chunk = AnthropicStreamEvent( + type="message_start", + message=AnthropicMessagesResponse( + id=origin_chunk.id, + content=[], + model=origin_chunk.model, + ), + ) + first_item = False + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "message_start") + continue + + # last chunk including usage info + if len(origin_chunk.choices) == 0: + if content_block_started: + stop_chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_stop", + ) + data = stop_chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_stop") + stop_reason = 
self.stop_reason_map.get( + finish_reason or "stop" + ) + chunk = AnthropicStreamEvent( + type="message_delta", + delta=AnthropicDelta(stop_reason=stop_reason), + usage=AnthropicUsage( + input_tokens=origin_chunk.usage.prompt_tokens + if origin_chunk.usage + else 0, + output_tokens=origin_chunk.usage.completion_tokens + if origin_chunk.usage + else 0, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "message_delta") + continue + + if origin_chunk.choices[0].finish_reason is not None: + finish_reason = origin_chunk.choices[0].finish_reason + continue + + # content + if origin_chunk.choices[0].delta.content is not None: + if not content_block_started: + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_start", + content_block=AnthropicContentBlock( + type="text", text="" + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_start") + content_block_started = True + + if origin_chunk.choices[0].delta.content == "": + continue + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_delta", + delta=AnthropicDelta( + type="text_delta", + text=origin_chunk.choices[0].delta.content, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_delta") + continue + + # tool calls + elif len(origin_chunk.choices[0].delta.tool_calls) > 0: + tool_call = origin_chunk.choices[0].delta.tool_calls[0] + if tool_call.id is not None: + if content_block_started: + stop_chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_stop", + ) + data = stop_chunk.model_dump_json( + exclude_unset=True + ) + yield wrap_data_with_event( + data, "content_block_stop" + ) + content_block_started = False + content_block_index += 1 + + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_start", + content_block=AnthropicContentBlock( + type="tool_use", + id=tool_call.id, + name=tool_call.function.name + if tool_call.function + else None, + input={}, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_start") + content_block_started = True + + else: + chunk = AnthropicStreamEvent( + index=content_block_index, + type="content_block_delta", + delta=AnthropicDelta( + type="input_json_delta", + partial_json=tool_call.function.arguments + if tool_call.function + else None, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_delta") + continue + else: + error_response = AnthropicStreamEvent( + type="error", + error=AnthropicError( + type="internal_error", + message="Invalid data format received", + ), + ) + data = error_response.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "error") + yield "data: [DONE]\n\n" + + except Exception as e: + logger.exception("Error in message stream converter.") + error_response = AnthropicStreamEvent( + type="error", + error=AnthropicError(type="internal_error", message=str(e)), + ) + data = error_response.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "error") + yield "data: [DONE]\n\n" diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index c31d15ddac4f..53dab90f45f7 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -13,7 +13,7 @@ import ssl from argparse import Namespace from collections.abc import AsyncGenerator -from 
typing import Any, Optional +from typing import Any from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -101,7 +101,7 @@ def build_app(args: Namespace) -> FastAPI: async def init_app( args: Namespace, - llm_engine: Optional[AsyncLLMEngine] = None, + llm_engine: AsyncLLMEngine | None = None, ) -> FastAPI: app = build_app(args) @@ -120,7 +120,7 @@ async def init_app( async def run_server( - args: Namespace, llm_engine: Optional[AsyncLLMEngine] = None, **uvicorn_kwargs: Any + args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any ) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 930b3bc69c3d..4c73e94fb72b 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -5,10 +5,10 @@ import json from abc import ABC, abstractmethod from collections import Counter, defaultdict, deque -from collections.abc import Awaitable, Iterable +from collections.abc import Awaitable, Callable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path -from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast import jinja2 import jinja2.ext @@ -40,7 +40,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin # pydantic needs the TypedDict from typing_extensions -from typing_extensions import Required, TypeAlias, TypedDict +from typing_extensions import Required, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger @@ -50,7 +50,8 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import random_uuid, supports_kw +from vllm.utils import random_uuid +from vllm.utils.func_utils import supports_kw logger = init_logger(__name__) @@ -76,7 +77,7 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False): class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): - image_embeds: Optional[Union[str, dict[str, str]]] + image_embeds: str | dict[str, str] | None """ The image embeddings. It can be either: - A single base64 string. @@ -84,7 +85,7 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): """ type: Required[Literal["image_embeds"]] """The type of the content part.""" - uuid: Optional[str] + uuid: str | None """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. @@ -123,8 +124,8 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): } """ - image_pil: Optional[PILImage] - uuid: Optional[str] + image_pil: PILImage | None + uuid: str | None """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. @@ -141,8 +142,8 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): } """ - image_url: Optional[str] - uuid: Optional[str] + image_url: str | None + uuid: str | None """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. 
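The custom content-part TypedDicts above take a bare key such as `image_url` plus an optional user-supplied `uuid`. The dict below is a minimal sketch of a message using them, following the example given in the TypedDict docstrings; the URL and uuid values are placeholders:

```python
# Illustrative only: one text part plus the simple image part with its
# optional user-provided uuid, as described in the docstrings above.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is shown in this image?"},
        {"image_url": "https://example.com/image.jpg", "uuid": "img-0001"},
    ],
}
```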
@@ -158,7 +159,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): } """ - audio_url: Optional[str] + audio_url: str | None class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): @@ -170,8 +171,8 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): } """ - video_url: Optional[str] - uuid: Optional[str] + video_url: str | None + uuid: str | None """ User-provided UUID of a media. User must guarantee that it is properly generated and unique for different medias. @@ -199,20 +200,20 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): """The thinking type.""" -ChatCompletionContentPartParam: TypeAlias = Union[ - OpenAIChatCompletionContentPartParam, - ChatCompletionContentPartAudioParam, - ChatCompletionContentPartInputAudioParam, - ChatCompletionContentPartVideoParam, - ChatCompletionContentPartRefusalParam, - CustomChatCompletionContentPILImageParam, - CustomChatCompletionContentSimpleImageParam, - ChatCompletionContentPartImageEmbedsParam, - CustomChatCompletionContentSimpleAudioParam, - CustomChatCompletionContentSimpleVideoParam, - str, - CustomThinkCompletionContentParam, -] +ChatCompletionContentPartParam: TypeAlias = ( + OpenAIChatCompletionContentPartParam + | ChatCompletionContentPartAudioParam + | ChatCompletionContentPartInputAudioParam + | ChatCompletionContentPartVideoParam + | ChatCompletionContentPartRefusalParam + | CustomChatCompletionContentPILImageParam + | CustomChatCompletionContentSimpleImageParam + | ChatCompletionContentPartImageEmbedsParam + | CustomChatCompletionContentSimpleAudioParam + | CustomChatCompletionContentSimpleVideoParam + | str + | CustomThinkCompletionContentParam +) class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -221,7 +222,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): role: Required[str] """The role of the message's author.""" - content: Union[str, list[ChatCompletionContentPartParam]] + content: str | list[ChatCompletionContentPartParam] """The contents of the message.""" name: str @@ -231,18 +232,18 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): same role. 
""" - tool_call_id: Optional[str] + tool_call_id: str | None """Tool call that this message is responding to.""" - tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None """The tool calls generated by the model, such as function calls.""" -ChatCompletionMessageParam = Union[ - OpenAIChatCompletionMessageParam, - CustomChatCompletionMessageParam, - OpenAIHarmonyMessage, -] +ChatCompletionMessageParam: TypeAlias = ( + OpenAIChatCompletionMessageParam + | CustomChatCompletionMessageParam + | OpenAIHarmonyMessage +) # TODO: Make fields ReadOnly once mypy supports it @@ -250,16 +251,16 @@ class ConversationMessage(TypedDict, total=False): role: Required[str] """The role of the message's author.""" - content: Union[Optional[str], list[dict[str, str]]] + content: str | None | list[dict[str, str]] """The contents of the message""" - tool_call_id: Optional[str] + tool_call_id: str | None """Tool call that this message is responding to.""" - name: Optional[str] + name: str | None """The name of the function to call""" - tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None """The tool calls generated by the model, such as function calls.""" @@ -294,7 +295,7 @@ def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: def _is_var_or_elems_access( node: jinja2.nodes.Node, varname: str, - key: Optional[str] = None, + key: str | None = None, ) -> bool: if isinstance(node, jinja2.nodes.Filter): return node.node is not None and _is_var_or_elems_access( @@ -369,7 +370,7 @@ def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): break -def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: +def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None: try: jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) return jinja_compiled.environment.parse(chat_template) @@ -400,27 +401,19 @@ def _detect_content_format( def resolve_mistral_chat_template( - chat_template: Optional[str], + chat_template: str | None, **kwargs: Any, -) -> Optional[str]: - if chat_template is not None: - logger.warning_once( - "'chat_template' cannot be overridden for mistral tokenizer." - ) - if "add_generation_prompt" in kwargs: - logger.warning_once( - "'add_generation_prompt' is not supported for mistral tokenizer, " - "so it will be ignored." - ) - if "continue_final_message" in kwargs: - logger.warning_once( - "'continue_final_message' is not supported for mistral tokenizer, " - "so it will be ignored." +) -> str | None: + if chat_template is not None or kwargs.get("chat_template_kwargs") is not None: + raise ValueError( + "'chat_template' or 'chat_template_kwargs' cannot be overridden " + "for mistral tokenizer." ) + return None -_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]() +_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]() """ Used in `_try_get_processor_chat_template` to avoid calling `cached_get_processor` again if the processor fails to be loaded. 
@@ -430,9 +423,9 @@ def resolve_mistral_chat_template( def _try_get_processor_chat_template( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, model_config: ModelConfig, -) -> Optional[str]: +) -> str | None: cache_key = (tokenizer.name_or_path, model_config.trust_remote_code) if cache_key in _PROCESSOR_CHAT_TEMPLATES: return _PROCESSOR_CHAT_TEMPLATES[cache_key] @@ -466,12 +459,12 @@ def _try_get_processor_chat_template( def resolve_hf_chat_template( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - chat_template: Optional[str], - tools: Optional[list[dict[str, Any]]], + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + chat_template: str | None, + tools: list[dict[str, Any]] | None, *, model_config: ModelConfig, -) -> Optional[str]: +) -> str | None: # 1st priority: The given chat template if chat_template is not None: return chat_template @@ -513,8 +506,8 @@ def resolve_hf_chat_template( def _resolve_chat_template_content_format( - chat_template: Optional[str], - tools: Optional[list[dict[str, Any]]], + chat_template: str | None, + tools: list[dict[str, Any]] | None, tokenizer: AnyTokenizer, *, model_config: ModelConfig, @@ -546,7 +539,7 @@ def _resolve_chat_template_content_format( @lru_cache def _log_chat_template_content_format( - chat_template: Optional[str], + chat_template: str | None, given_format: ChatTemplateContentFormatOption, detected_format: ChatTemplateContentFormatOption, ): @@ -569,8 +562,8 @@ def _log_chat_template_content_format( def resolve_chat_template_content_format( - chat_template: Optional[str], - tools: Optional[list[dict[str, Any]]], + chat_template: str | None, + tools: list[dict[str, Any]] | None, given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, *, @@ -612,8 +605,8 @@ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): self._model_config = model_config self._tokenizer = tokenizer - self._items_by_modality = defaultdict[str, list[Optional[_T]]](list) - self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list) + self._items_by_modality = defaultdict[str, list[_T | None]](list) + self._uuids_by_modality = defaultdict[str, list[str | None]](list) @property def model_config(self) -> ModelConfig: @@ -645,9 +638,9 @@ def mm_processor(self): def add( self, modality: ModalityStr, - item: Optional[_T], - uuid: Optional[str] = None, - ) -> Optional[str]: + item: _T | None, + uuid: str | None = None, + ) -> str | None: """ Add a multi-modal item to the current prompt and returns the placeholder string to use, if any. 
@@ -665,7 +658,7 @@ def add( return self.model_cls.get_placeholder_str(modality, num_items) - def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]: + def all_mm_uuids(self) -> MultiModalUUIDDict | None: if not self._items_by_modality: return None mm_uuids = {} @@ -692,7 +685,7 @@ def create_parser(self) -> "BaseMultiModalContentParser": class MultiModalItemTracker(BaseMultiModalItemTracker[object]): - def all_mm_data(self) -> Optional[MultiModalDataDict]: + def all_mm_data(self) -> MultiModalDataDict | None: if not self._items_by_modality: return None mm_inputs = {} @@ -718,7 +711,7 @@ def create_parser(self) -> "BaseMultiModalContentParser": class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): - async def all_mm_data(self) -> Optional[MultiModalDataDict]: + async def all_mm_data(self) -> MultiModalDataDict | None: if not self._items_by_modality: return None mm_inputs = {} @@ -764,7 +757,7 @@ def __init__(self) -> None: # } self._placeholder_storage: dict[str, list] = defaultdict(list) - def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]): + def _add_placeholder(self, modality: ModalityStr, placeholder: str | None): mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] if placeholder: self._placeholder_storage[mod_placeholder].append(placeholder) @@ -773,35 +766,35 @@ def mm_placeholder_storage(self) -> dict[str, list]: return dict(self._placeholder_storage) @abstractmethod - def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: raise NotImplementedError @abstractmethod def parse_image_embeds( self, - image_embeds: Union[str, dict[str, str], None], - uuid: Optional[str] = None, + image_embeds: str | dict[str, str] | None, + uuid: str | None = None, ) -> None: raise NotImplementedError @abstractmethod def parse_image_pil( - self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + self, image_pil: Image.Image | None, uuid: str | None = None ) -> None: raise NotImplementedError @abstractmethod - def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None: raise NotImplementedError @abstractmethod def parse_input_audio( - self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + self, input_audio: InputAudio | None, uuid: str | None = None ) -> None: raise NotImplementedError @abstractmethod - def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_video(self, video_url: str | None, uuid: str | None = None) -> None: raise NotImplementedError @@ -818,7 +811,11 @@ def __init__(self, tracker: MultiModalItemTracker) -> None: allowed_media_domains=tracker.allowed_media_domains, ) - def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: + @property + def model_config(self) -> ModelConfig: + return self._tracker.model_config + + def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: image = self._connector.fetch_image(image_url) if image_url else None placeholder = self._tracker.add("image", image, uuid) @@ -826,9 +823,15 @@ def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> N def parse_image_embeds( self, - image_embeds: Union[str, dict[str, str], None], - uuid: Optional[str] = None, + image_embeds: str | dict[str, str] | None, + uuid: str | None = None, ) -> None: + mm_config = 
self.model_config.get_multimodal_config() + if not mm_config.enable_mm_embeds: + raise ValueError( + "You must set `--enable-mm-embeds` to input `image_embeds`" + ) + if isinstance(image_embeds, dict): embeds = { k: self._connector.fetch_image_embedding(v) @@ -846,19 +849,19 @@ def parse_image_embeds( self._add_placeholder("image", placeholder) def parse_image_pil( - self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + self, image_pil: Image.Image | None, uuid: str | None = None ) -> None: placeholder = self._tracker.add("image", image_pil, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None: audio = self._connector.fetch_audio(audio_url) if audio_url else None placeholder = self._tracker.add("audio", audio, uuid) self._add_placeholder("audio", placeholder) def parse_input_audio( - self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + self, input_audio: InputAudio | None, uuid: str | None = None ) -> None: if input_audio: audio_data = input_audio.get("data", "") @@ -873,7 +876,7 @@ def parse_input_audio( return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_video(self, video_url: str | None, uuid: str | None = None) -> None: video = self._connector.fetch_video(video_url=video_url) if video_url else None placeholder = self._tracker.add("video", video, uuid) @@ -893,7 +896,11 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: allowed_media_domains=tracker.allowed_media_domains, ) - def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: + @property + def model_config(self) -> ModelConfig: + return self._tracker.model_config + + def parse_image(self, image_url: str | None, uuid: str | None = None) -> None: image_coro = self._connector.fetch_image_async(image_url) if image_url else None placeholder = self._tracker.add("image", image_coro, uuid) @@ -901,10 +908,16 @@ def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> N def parse_image_embeds( self, - image_embeds: Union[str, dict[str, str], None], - uuid: Optional[str] = None, + image_embeds: str | dict[str, str] | None, + uuid: str | None = None, ) -> None: - future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future() + mm_config = self.model_config.get_multimodal_config() + if not mm_config.enable_mm_embeds: + raise ValueError( + "You must set `--enable-mm-embeds` to input `image_embeds`" + ) + + future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future() if isinstance(image_embeds, dict): embeds = { @@ -924,9 +937,9 @@ def parse_image_embeds( self._add_placeholder("image", placeholder) def parse_image_pil( - self, image_pil: Optional[Image.Image], uuid: Optional[str] = None + self, image_pil: Image.Image | None, uuid: str | None = None ) -> None: - future: asyncio.Future[Optional[Image.Image]] = asyncio.Future() + future: asyncio.Future[Image.Image | None] = asyncio.Future() if image_pil: future.set_result(image_pil) else: @@ -935,14 +948,14 @@ def parse_image_pil( placeholder = self._tracker.add("image", future, uuid) self._add_placeholder("image", placeholder) - def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None: audio_coro = 
self._connector.fetch_audio_async(audio_url) if audio_url else None placeholder = self._tracker.add("audio", audio_coro, uuid) self._add_placeholder("audio", placeholder) def parse_input_audio( - self, input_audio: Optional[InputAudio], uuid: Optional[str] = None + self, input_audio: InputAudio | None, uuid: str | None = None ) -> None: if input_audio: audio_data = input_audio.get("data", "") @@ -957,7 +970,7 @@ def parse_input_audio( return self.parse_audio(audio_url, uuid) - def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_video(self, video_url: str | None, uuid: str | None = None) -> None: video = ( self._connector.fetch_video_async(video_url=video_url) if video_url @@ -968,7 +981,7 @@ def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> N self._add_placeholder("video", placeholder) -def validate_chat_template(chat_template: Optional[Union[Path, str]]): +def validate_chat_template(chat_template: Path | str | None): """Raises if the provided chat template appears invalid.""" if chat_template is None: return @@ -992,10 +1005,10 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): def _load_chat_template( - chat_template: Optional[Union[Path, str]], + chat_template: Path | str | None, *, is_literal: bool = False, -) -> Optional[str]: +) -> str | None: if chat_template is None: return None @@ -1032,10 +1045,10 @@ def _load_chat_template( def load_chat_template( - chat_template: Optional[Union[Path, str]], + chat_template: Path | str | None, *, is_literal: bool = False, -) -> Optional[str]: +) -> str | None: return _cached_load_chat_template(chat_template, is_literal=is_literal) @@ -1115,7 +1128,7 @@ def _get_full_multimodal_text_prompt( _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python _ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python -_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] +_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage # Define a mapping from part types to their corresponding parsing functions. MM_PARSER_MAP: dict[ @@ -1272,7 +1285,7 @@ def _parse_chat_message_content_part( *, wrap_dicts: bool, interleave_strings: bool, -) -> Optional[_ContentPart]: +) -> _ContentPart | None: """Parses a single part of a conversation. 
If wrap_dicts is True, structured dictionary pieces for texts and images will be wrapped in dictionaries, i.e., {"type": "text", "text", ...} and @@ -1318,10 +1331,7 @@ def _parse_chat_message_content_part( mm_parser.parse_image(str_content, uuid) modality = "image" elif part_type == "image_embeds": - if content is not None: - content = cast(Union[str, dict[str, str]], content) - else: - content = None + content = cast(str | dict[str, str], content) if content is not None else None mm_parser.parse_image_embeds(content, uuid) modality = "image" elif part_type == "audio_url": @@ -1419,8 +1429,8 @@ def parse_chat_messages( content_format: _ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], - Optional[MultiModalDataDict], - Optional[MultiModalUUIDDict], + MultiModalDataDict | None, + MultiModalUUIDDict | None, ]: conversation: list[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) @@ -1451,8 +1461,8 @@ def parse_chat_messages_futures( content_format: _ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], - Awaitable[Optional[MultiModalDataDict]], - Optional[MultiModalUUIDDict], + Awaitable[MultiModalDataDict | None], + MultiModalUUIDDict | None, ]: conversation: list[ConversationMessage] = [] mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) @@ -1506,33 +1516,39 @@ def _resolve_chat_template_kwargs( def resolve_chat_template_kwargs( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, chat_template: str, chat_template_kwargs: dict[str, Any], + raise_on_unexpected: bool = True, ) -> dict[str, Any]: + # We exclude chat_template from kwargs here, because + # chat template has been already resolved at this stage + unexpected_vars = {"chat_template", "tokenize"} + if raise_on_unexpected and ( + unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys() + ): + raise ValueError( + "Found unexpected chat template kwargs from request: " + f"{unexpected_in_kwargs}" + ) + fn_kw = { k for k in chat_template_kwargs if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False) } - template_vars = _cached_resolve_chat_template_kwargs(chat_template) - - # We exclude chat_template from kwargs here, because - # chat template has been already resolved at this stage - unexpected_vars = {"chat_template"} accept_vars = (fn_kw | template_vars) - unexpected_vars return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars} def apply_hf_chat_template( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, conversation: list[ConversationMessage], - chat_template: Optional[str], - tools: Optional[list[dict[str, Any]]], + chat_template: str | None, + tools: list[dict[str, Any]] | None, *, model_config: ModelConfig, - tokenize: bool = False, # Different from HF's default **kwargs: Any, ) -> str: hf_chat_template = resolve_hf_chat_template( @@ -1549,17 +1565,18 @@ def apply_hf_chat_template( "does not define one." 
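The reworked `resolve_chat_template_kwargs` above now rejects `chat_template` and `tokenize` coming from the request and keeps only kwargs accepted by the tokenizer's `apply_chat_template` or referenced by the Jinja template. A self-contained sketch of that filtering rule (a re-implementation for illustration, not the vLLM helper; `enable_thinking` is just an example template variable):

    def filter_chat_template_kwargs(
        request_kwargs: dict,
        fn_kwargs: set[str],      # kwargs apply_chat_template() accepts
        template_vars: set[str],  # variables referenced by the Jinja template
    ) -> dict:
        unexpected = {"chat_template", "tokenize"}
        if unexpected & request_kwargs.keys():
            raise ValueError("Found unexpected chat template kwargs from request")
        accept = (fn_kwargs | template_vars) - unexpected
        return {k: v for k, v in request_kwargs.items() if k in accept}

    print(filter_chat_template_kwargs(
        {"enable_thinking": True, "foo": 1},       # "foo" is silently dropped
        fn_kwargs={"add_generation_prompt"},
        template_vars={"enable_thinking"},
    ))  # -> {'enable_thinking': True}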
) + resolved_kwargs = resolve_chat_template_kwargs( + tokenizer=tokenizer, + chat_template=hf_chat_template, + chat_template_kwargs=kwargs, + ) + try: - resolved_kwargs = resolve_chat_template_kwargs( - tokenizer=tokenizer, - chat_template=hf_chat_template, - chat_template_kwargs=kwargs, - ) return tokenizer.apply_chat_template( conversation=conversation, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type] chat_template=hf_chat_template, - tokenize=tokenize, + tokenize=False, **resolved_kwargs, ) @@ -1577,8 +1594,8 @@ def apply_hf_chat_template( def apply_mistral_chat_template( tokenizer: MistralTokenizer, messages: list[ChatCompletionMessageParam], - chat_template: Optional[str], - tools: Optional[list[dict[str, Any]]], + chat_template: str | None, + tools: list[dict[str, Any]] | None, **kwargs: Any, ) -> list[int]: from mistral_common.exceptions import MistralCommonException diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py index d7455daa1a6b..7a1d24776009 100644 --- a/vllm/entrypoints/cli/benchmark/main.py +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import argparse import typing @@ -12,6 +10,8 @@ if typing.TYPE_CHECKING: from vllm.utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = argparse.ArgumentParser class BenchmarkSubcommand(CLISubcommand): diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py index e79a7efec6ba..e47dce0a401a 100644 --- a/vllm/entrypoints/cli/collect_env.py +++ b/vllm/entrypoints/cli/collect_env.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import argparse import typing @@ -11,6 +9,8 @@ if typing.TYPE_CHECKING: from vllm.utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = argparse.ArgumentParser class CollectEnvSubcommand(CLISubcommand): diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index 0ebfe1c22269..213a46603622 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -5,9 +5,12 @@ Note that all future modules must be lazily loaded within main to avoid certain eager import breakage.""" -from __future__ import annotations - import importlib.metadata +import sys + +from vllm.logger import init_logger + +logger = init_logger(__name__) def main(): @@ -29,6 +32,22 @@ def main(): cli_env_setup() + # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default + if len(sys.argv) > 1 and sys.argv[1] == "bench": + logger.debug( + "Bench command detected, must ensure current platform is not " + "UnspecifiedPlatform to avoid device type inference error" + ) + from vllm import platforms + + if platforms.current_platform.is_unspecified(): + from vllm.platforms.cpu import CpuPlatform + + platforms.current_platform = CpuPlatform() + logger.info( + "Unspecified platform detected, switching to CPU Platform instead." 
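The CLI modules above stop relying on `from __future__ import annotations` and instead bind `FlexibleArgumentParser` to `argparse.ArgumentParser` at runtime, so annotations still resolve when evaluated. A standard-library-only sketch of that pattern (`some_heavy_module`/`FancyParser` are hypothetical):

    import argparse
    import typing

    if typing.TYPE_CHECKING:
        # Seen only by type checkers; never imported at runtime.
        from some_heavy_module import FancyParser as ParserType  # hypothetical
    else:
        # Runtime alias so the evaluated annotation below still resolves.
        ParserType = argparse.ArgumentParser

    def build_parser() -> ParserType:
        # The concrete runtime object is a plain ArgumentParser.
        return argparse.ArgumentParser(prog="demo")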
+ ) + parser = FlexibleArgumentParser( description="vLLM CLI", epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"), diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 5372210bbf55..a27c6fe6618a 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import argparse import os import signal @@ -16,6 +14,8 @@ if TYPE_CHECKING: from vllm.utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = argparse.ArgumentParser def _register_signal_handlers(): diff --git a/vllm/entrypoints/cli/run_batch.py b/vllm/entrypoints/cli/run_batch.py index 6e7a15ada49c..4b18ceb5215f 100644 --- a/vllm/entrypoints/cli/run_batch.py +++ b/vllm/entrypoints/cli/run_batch.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import argparse import asyncio import importlib.metadata @@ -14,6 +12,8 @@ if typing.TYPE_CHECKING: from vllm.utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = argparse.ArgumentParser logger = init_logger(__name__) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index b3960b74cf01..e4ba66024135 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -3,7 +3,6 @@ import argparse import signal -from typing import Optional import uvloop @@ -19,15 +18,12 @@ from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import ( - FlexibleArgumentParser, - decorate_logs, - get_tcp_uri, - set_process_title, -) +from vllm.utils import FlexibleArgumentParser +from vllm.utils.network_utils import get_tcp_uri +from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure @@ -179,7 +175,7 @@ def run_multi_api_server(args: argparse.Namespace): hybrid_dp_lb = parallel_config.data_parallel_hybrid_lb assert external_dp_lb or hybrid_dp_lb or dp_rank == 0 - api_server_manager: Optional[APIServerProcessManager] = None + api_server_manager: APIServerProcessManager | None = None with launch_core_engines( vllm_config, executor_class, log_stats, num_api_servers diff --git a/vllm/entrypoints/cli/types.py b/vllm/entrypoints/cli/types.py index 6194f421a1bb..f4eeb5b3c2e1 100644 --- a/vllm/entrypoints/cli/types.py +++ b/vllm/entrypoints/cli/types.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import argparse import typing if typing.TYPE_CHECKING: from vllm.utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = argparse.ArgumentParser class CLISubcommand: diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index f410ee9c4045..8886d7c42d8a 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from contextlib import 
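The `vllm bench` change in cli/main.py above amounts to the guard sketched below, paraphrased from the hunk itself (the `platforms` and `CpuPlatform` names are taken from the patch): when no accelerator platform was detected, bench subcommands fall back to CPU instead of failing device-type inference.

    from vllm import platforms
    from vllm.logger import init_logger

    logger = init_logger(__name__)

    def maybe_default_bench_to_cpu(argv: list[str]) -> None:
        # Only the `vllm bench ...` path gets the fallback.
        if len(argv) > 1 and argv[1] == "bench":
            if platforms.current_platform.is_unspecified():
                from vllm.platforms.cpu import CpuPlatform

                platforms.current_platform = CpuPlatform()
                logger.info("Unspecified platform detected, switching to CPU.")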
AsyncExitStack -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Union from openai.types.responses.tool import Mcp from openai_harmony import Author, Message, Role, StreamState, TextContent @@ -45,21 +45,36 @@ def _map_tool_name_to_tool_type(tool_name: str) -> str: return _TOOL_NAME_TO_TYPE_MAP[tool_name] -class TurnTokens: - """Tracks token counts for a single conversation turn.""" +class TurnMetrics: + """Tracks token and toolcall details for a single conversation turn.""" - def __init__(self, input_tokens=0, output_tokens=0): + def __init__( + self, + input_tokens=0, + output_tokens=0, + cached_input_tokens=0, + tool_output_tokens=0, + ): self.input_tokens = input_tokens self.output_tokens = output_tokens + self.cached_input_tokens = cached_input_tokens + self.tool_output_tokens = tool_output_tokens def reset(self): """Reset counters for a new turn.""" self.input_tokens = 0 self.output_tokens = 0 + self.cached_input_tokens = 0 + self.tool_output_tokens = 0 def copy(self): """Create a copy of this turn's token counts.""" - return TurnTokens(self.input_tokens, self.output_tokens) + return TurnMetrics( + self.input_tokens, + self.output_tokens, + self.cached_input_tokens, + self.tool_output_tokens, + ) class ConversationContext(ABC): @@ -82,7 +97,7 @@ def render_for_completion(self) -> list[int]: @abstractmethod async def init_tool_sessions( self, - tool_server: Optional[ToolServer], + tool_server: ToolServer | None, exit_stack: AsyncExitStack, request_id: str, mcp_tools: dict[str, Mcp], @@ -102,6 +117,8 @@ def __init__(self): self.num_cached_tokens = 0 # todo num_reasoning_tokens is not implemented yet. self.num_reasoning_tokens = 0 + # not implemented yet for SimpleContext + self.all_turn_metrics = [] def append_output(self, output) -> None: self.last_output = output @@ -122,7 +139,7 @@ def render_for_completion(self) -> list[int]: async def init_tool_sessions( self, - tool_server: Optional[ToolServer], + tool_server: ToolServer | None, exit_stack: AsyncExitStack, request_id: str, mcp_tools: dict[str, Mcp], @@ -140,9 +157,9 @@ def __init__( available_tools: list[str], ): self._messages = messages - self.finish_reason: Optional[str] = None + self.finish_reason: str | None = None self.available_tools = available_tools - self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {} + self._tool_sessions: dict[str, ClientSession | Tool] = {} self.called_tools: set[str] = set() self.parser = get_streamable_parser_for_assistant() @@ -154,8 +171,9 @@ def __init__( self.num_tool_output_tokens = 0 # Turn tracking - replaces multiple individual tracking variables - self.current_turn = TurnTokens() - self.previous_turn = TurnTokens() + self.current_turn_metrics = TurnMetrics() + # Track metrics for all turns + self.all_turn_metrics: list[TurnMetrics] = [] self.is_first_turn = True self.first_tok_of_message = True # For streaming support @@ -164,7 +182,7 @@ def _update_num_reasoning_tokens(self): if self.parser.current_channel in {"analysis", "commentary"}: self.num_reasoning_tokens += 1 - def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: + def append_output(self, output: RequestOutput | list[Message]) -> None: if isinstance(output, RequestOutput): output_token_ids = output.outputs[0].token_ids self.parser = get_streamable_parser_for_assistant() @@ -173,11 +191,10 @@ def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: # Check if the current token is part of reasoning content 
self._update_num_reasoning_tokens() self._update_prefill_token_usage(output) - # Reset current turn output tokens for this turn - self.current_turn.output_tokens = 0 self._update_decode_token_usage(output) - # Move current turn to previous turn for next turn's calculations - self.previous_turn = self.current_turn.copy() + # Append current turn to all turn list for next turn's calculations + self.all_turn_metrics.append(self.current_turn_metrics.copy()) + self.current_turn_metrics.reset() # append_output is called only once before tool calling # in non-streaming case # so we can append all the parser messages to _messages @@ -213,20 +230,21 @@ def _update_prefill_token_usage(self, output: RequestOutput) -> None: logger.error("RequestOutput appended contains no prompt_token_ids.") # Update current turn input tokens - self.current_turn.input_tokens = this_turn_input_tokens + self.current_turn_metrics.input_tokens = this_turn_input_tokens self.num_prompt_tokens += this_turn_input_tokens # Calculate tool tokens (except on first turn) if self.is_first_turn: self.is_first_turn = False else: + previous_turn = self.all_turn_metrics[-1] # start counting tool after first turn # tool tokens = this turn prefill - last turn prefill - # last turn decode this_turn_tool_tokens = ( - self.current_turn.input_tokens - - self.previous_turn.input_tokens - - self.previous_turn.output_tokens + self.current_turn_metrics.input_tokens + - previous_turn.input_tokens + - previous_turn.output_tokens ) # Handle negative tool token counts (shouldn't happen in normal @@ -237,17 +255,20 @@ def _update_prefill_token_usage(self, output: RequestOutput) -> None: "(current_input=%d, previous_input=%d, " "previous_output=%d). Setting to 0.", this_turn_tool_tokens, - self.current_turn.input_tokens, - self.previous_turn.input_tokens, - self.previous_turn.output_tokens, + self.current_turn_metrics.input_tokens, + previous_turn.input_tokens, + previous_turn.output_tokens, ) this_turn_tool_tokens = 0 self.num_tool_output_tokens += this_turn_tool_tokens + self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens # Update cached tokens - if output.num_cached_tokens is not None: - self.num_cached_tokens += output.num_cached_tokens + num_cached_token = output.num_cached_tokens + if num_cached_token is not None: + self.num_cached_tokens += num_cached_token + self.current_turn_metrics.cached_input_tokens = num_cached_token def _update_decode_token_usage(self, output: RequestOutput) -> int: """Update token usage statistics for the decode phase of generation. 
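A worked example of the per-turn tool-token accounting above, with made-up numbers: if turn 1 prefilled 100 tokens and decoded 20, and turn 2 prefills 150 tokens, then 150 - 100 - 20 = 30 tokens are attributed to tool output, and negative results are clamped to zero. A minimal standalone sketch of that bookkeeping (not the vLLM classes):

    from dataclasses import dataclass

    @dataclass
    class Turn:
        input_tokens: int = 0
        output_tokens: int = 0

    def tool_output_tokens(current: Turn, previous: Turn | None) -> int:
        if previous is None:
            return 0  # nothing to attribute on the first turn
        # This turn's prefill minus everything the previous turn accounts for
        # (its own prompt plus the tokens it generated).
        extra = current.input_tokens - previous.input_tokens - previous.output_tokens
        return max(extra, 0)  # clamp, mirroring the warning path above

    turns = [Turn(input_tokens=100, output_tokens=20), Turn(input_tokens=150)]
    print(tool_output_tokens(turns[1], turns[0]))  # -> 30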
@@ -272,7 +293,7 @@ def _update_decode_token_usage(self, output: RequestOutput) -> int: # only keep last round updated_output_token_count += len(completion_output.token_ids) self.num_output_tokens += updated_output_token_count - self.current_turn.output_tokens += updated_output_token_count + self.current_turn_metrics.output_tokens += updated_output_token_count return updated_output_token_count @property @@ -358,7 +379,7 @@ async def call_python_tool( async def init_tool_sessions( self, - tool_server: Optional[ToolServer], + tool_server: ToolServer | None, exit_stack: AsyncExitStack, request_id: str, mcp_tools: dict[str, Mcp], @@ -446,13 +467,12 @@ def __init__(self, *args, **kwargs): def messages(self) -> list: return self._messages - def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: + def append_output(self, output: RequestOutput | list[Message]) -> None: if isinstance(output, RequestOutput): # append_output is called for each output token in streaming case, # so we only want to add the prompt tokens once for each message. if self.first_tok_of_message: self._update_prefill_token_usage(output) - self.current_turn.output_tokens = 0 # Reset self.first_tok_of_message if needed: # if the current token is the last one of the current message # (finished=True), then the next token processed will mark the @@ -464,7 +484,8 @@ def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: # For streaming, update previous turn when message is complete if output.finished: - self.previous_turn = self.current_turn.copy() + self.all_turn_metrics.append(self.current_turn_metrics.copy()) + self.current_turn_metrics.reset() # Check if the current token is part of reasoning content self._update_num_reasoning_tokens() self.last_tok = tok @@ -494,7 +515,7 @@ def is_assistant_action_turn(self) -> bool: def render_for_completion(self) -> list[int]: # now this list of tokens as next turn's starting tokens - # `<|start|>assistant``, + # `<|start|>assistant`, # we need to process them in parser. 
rendered_tokens = super().render_for_completion() diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index bf6cc3e97c82..fe581e5484e1 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import datetime import json from collections.abc import Iterable, Sequence -from typing import Literal, Union +from typing import Literal from openai.types.responses import ( ResponseFunctionToolCall, @@ -122,7 +120,7 @@ def get_system_message( return sys_msg -def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]): +def create_tool_definition(tool: ChatCompletionToolsParam | Tool): if isinstance(tool, ChatCompletionToolsParam): return ToolDescription.new( name=tool.function.name, @@ -138,13 +136,13 @@ def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]): def get_developer_message( instructions: str | None = None, - tools: list[Union[Tool, ChatCompletionToolsParam]] | None = None, + tools: list[Tool | ChatCompletionToolsParam] | None = None, ) -> Message: dev_msg_content = DeveloperContent.new() if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: dev_msg_content = dev_msg_content.with_instructions(instructions) if tools is not None: - function_tools: list[Union[Tool, ChatCompletionToolsParam]] = [] + function_tools: list[Tool | ChatCompletionToolsParam] = [] for tool in tools: if tool.type in ( "web_search_preview", @@ -178,7 +176,7 @@ def get_user_message(content: str) -> Message: def parse_response_input( response_msg: ResponseInputOutputItem, - prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]], + prev_responses: list[ResponseOutputItem | ResponseReasoningItem], ) -> Message: if not isinstance(response_msg, dict): response_msg = response_msg.model_dump() @@ -256,6 +254,15 @@ def parse_chat_input(chat_msg) -> list[Message]: if role == "tool": name = chat_msg.get("name", "") content = chat_msg.get("content", "") or "" + if isinstance(content, list): + # Handle array format for tool message content + # by concatenating all text parts. 
+ content = "".join( + item.get("text", "") + for item in content + if isinstance(item, dict) and item.get("type") == "text" + ) + msg = Message.from_author_and_content( Author.new(Role.TOOL, f"functions.{name}"), content ).with_channel("commentary") diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 349437363c5b..cabf95e8d214 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -5,7 +5,7 @@ import signal import socket from http import HTTPStatus -from typing import Any, Optional +from typing import Any import uvicorn from fastapi import FastAPI, Request, Response @@ -18,7 +18,7 @@ ) from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger -from vllm.utils import find_process_using_port +from vllm.utils.network_utils import find_process_using_port from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError logger = init_logger(__name__) @@ -26,7 +26,7 @@ async def serve_http( app: FastAPI, - sock: Optional[socket.socket], + sock: socket.socket | None, enable_ssl_refresh: bool = False, **uvicorn_kwargs: Any, ): diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c797735f0c2d..c15b70a06809 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, cast import cloudpickle import torch.nn as nn @@ -19,18 +19,18 @@ ) from vllm.config import ( CompilationConfig, - ModelDType, + PoolerConfig, StructuredOutputsConfig, - TokenizerMode, is_init_field, ) -from vllm.engine.arg_utils import ( +from vllm.config.model import ( ConvertOption, - EngineArgs, HfOverrides, - PoolerConfig, + ModelDType, RunnerOption, + TokenizerMode, ) +from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, ChatTemplateContentFormatOption, @@ -66,7 +66,6 @@ RequestOutput, ScoringRequestOutput, ) -from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask @@ -76,10 +75,10 @@ get_cached_tokenizer, ) from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, Device, as_iter, is_list_of +from vllm.utils import Counter, Device +from vllm.utils.collection_utils import as_iter, is_list_of from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.llm_engine import LLMEngine -from vllm.v1.engine.processor import Processor from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: @@ -119,9 +118,8 @@ class LLM: execution with tensor parallelism. dtype: The data type for the model weights and activations. Currently, we support `float32`, `float16`, and `bfloat16`. If `auto`, we use - the `torch_dtype` attribute specified in the model config file. - However, if the `torch_dtype` in the config is `float32`, we will - use `float16` instead. + the `dtype` attribute of the Transformers model's config. However, + if the `dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, we support "awq", "gptq", and "fp8" (experimental). 
If None, we first check the `quantization_config` attribute in the @@ -178,7 +176,7 @@ class LLM: argument is deprecated and will be removed in v0.12.0 or v1.0.0, whichever is sooner. compilation_config: Either an integer or a dictionary. If it is an - integer, it is used as the level of compilation optimization. If it + integer, it is used as the mode of compilation optimization. If it is a dictionary, it can specify the full compilation configuration. **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. @@ -193,36 +191,34 @@ def __init__( *, runner: RunnerOption = "auto", convert: ConvertOption = "auto", - tokenizer: Optional[str] = None, + tokenizer: str | None = None, tokenizer_mode: TokenizerMode = "auto", skip_tokenizer_init: bool = False, trust_remote_code: bool = False, allowed_local_media_path: str = "", - allowed_media_domains: Optional[list[str]] = None, + allowed_media_domains: list[str] | None = None, tensor_parallel_size: int = 1, dtype: ModelDType = "auto", - quantization: Optional[QuantizationMethods] = None, - revision: Optional[str] = None, - tokenizer_revision: Optional[str] = None, - seed: Optional[int] = None, + quantization: QuantizationMethods | None = None, + revision: str | None = None, + tokenizer_revision: str | None = None, + seed: int | None = None, gpu_memory_utilization: float = 0.9, swap_space: float = 4, cpu_offload_gb: float = 0, enforce_eager: bool = False, disable_custom_all_reduce: bool = False, - hf_token: Optional[Union[bool, str]] = None, - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[dict[str, Any]] = None, - pooler_config: Optional[PoolerConfig] = None, - override_pooler_config: Optional[PoolerConfig] = None, - structured_outputs_config: Optional[ - Union[dict[str, Any], StructuredOutputsConfig] - ] = None, - kv_cache_memory_bytes: Optional[int] = None, - compilation_config: Optional[ - Union[int, dict[str, Any], CompilationConfig] - ] = None, - logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None, + hf_token: bool | str | None = None, + hf_overrides: HfOverrides | None = None, + mm_processor_kwargs: dict[str, Any] | None = None, + pooler_config: PoolerConfig | None = None, + override_pooler_config: PoolerConfig | None = None, + structured_outputs_config: dict[str, Any] + | StructuredOutputsConfig + | None = None, + kv_cache_memory_bytes: int | None = None, + compilation_config: int | dict[str, Any] | CompilationConfig | None = None, + logits_processors: list[str | type[LogitsProcessor]] | None = None, **kwargs: Any, ) -> None: """LLM constructor.""" @@ -261,9 +257,7 @@ def __init__( if compilation_config is not None: if isinstance(compilation_config, int): - compilation_config_instance = CompilationConfig( - level=compilation_config - ) + compilation_config_instance = CompilationConfig(mode=compilation_config) elif isinstance(compilation_config, dict): compilation_config_instance = CompilationConfig( **{ @@ -291,6 +285,17 @@ def __init__( else: structured_outputs_instance = StructuredOutputsConfig() + # warn about single-process data parallel usage. + _dp_size = int(kwargs.get("data_parallel_size", 1)) + _distributed_executor_backend = kwargs.get("distributed_executor_backend") + if _dp_size > 1 and not _distributed_executor_backend == "external_launcher": + raise ValueError( + f"LLM(data_parallel_size={_dp_size}) is not supported for single-" + "process usage and may hang. 
Please use " + "the explicit multi-process data-parallel example at " + "'examples/offline_inference/data_parallel.py'." + ) + engine_args = EngineArgs( model=model, runner=runner, @@ -333,23 +338,15 @@ def __init__( self.engine_class = type(self.llm_engine) self.request_counter = Counter() - self.default_sampling_params: Union[dict[str, Any], None] = None - - supported_tasks = self.llm_engine.get_supported_tasks() # type: ignore - - logger.info("Supported_tasks: %s", supported_tasks) + self.default_sampling_params: dict[str, Any] | None = None + supported_tasks = self.llm_engine.get_supported_tasks() + logger.info("Supported tasks: %s", supported_tasks) self.supported_tasks = supported_tasks - # Load the Input/Output processor plugin if any - io_processor_plugin = self.llm_engine.model_config.io_processor_plugin - self.io_processor = get_io_processor( - self.llm_engine.vllm_config, io_processor_plugin - ) - - @property - def model_config(self): - return self.llm_engine.model_config + self.model_config = self.llm_engine.model_config + self.processor = self.llm_engine.processor + self.io_processor = self.llm_engine.io_processor def get_tokenizer(self) -> AnyTokenizer: return self.llm_engine.get_tokenizer() @@ -364,32 +361,25 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: else: self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) - def _get_processor(self) -> Processor: - if not hasattr(self, "_processor"): - vllm_config = self.llm_engine.vllm_config - self._processor = Processor(vllm_config) - - return self._processor + def reset_mm_cache(self) -> None: + self.processor.clear_mm_cache() + self.llm_engine.reset_mm_cache() def get_default_sampling_params(self) -> SamplingParams: if self.default_sampling_params is None: - self.default_sampling_params = ( - self.llm_engine.model_config.get_diff_sampling_param() - ) + self.default_sampling_params = self.model_config.get_diff_sampling_param() if self.default_sampling_params: return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams() def generate( self, - prompts: Union[PromptType, Sequence[PromptType]], - sampling_params: Optional[ - Union[SamplingParams, Sequence[SamplingParams]] - ] = None, + prompts: PromptType | Sequence[PromptType], + sampling_params: SamplingParams | Sequence[SamplingParams] | None = None, *, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - priority: Optional[list[int]] = None, + use_tqdm: bool | Callable[..., tqdm] = True, + lora_request: list[LoRARequest] | LoRARequest | None = None, + priority: list[int] | None = None, ) -> list[RequestOutput]: """Generates the completions for the input prompts. @@ -423,7 +413,7 @@ def generate( considered legacy and may be deprecated in the future. You should instead pass them via the `inputs` parameter. """ - model_config = self.llm_engine.model_config + model_config = self.model_config runner_type = model_config.runner_type if runner_type != "generate": raise ValueError( @@ -452,8 +442,8 @@ def generate( def _get_modality_specific_lora_reqs( self, - prompts: Union[PromptType, Sequence[PromptType]], - lora_request: Optional[Union[list[LoRARequest], LoRARequest]], + prompts: PromptType | Sequence[PromptType], + lora_request: list[LoRARequest] | LoRARequest | None, ): # Grab the lora config off the vllm config on the engine, # since this is the same for both v0 & v1. 
@@ -463,7 +453,7 @@ def _get_modality_specific_lora_reqs( # isn't multimodal, leave the lora as is. if ( lora_config is None - or not self.llm_engine.model_config.is_multimodal_model + or not self.model_config.is_multimodal_model or (lora_config and lora_config.default_mm_loras is None) ): return lora_request @@ -489,21 +479,19 @@ def _get_modality_specific_lora_reqs( def _resolve_single_prompt_mm_lora( self, prompt: PromptType, - lora_request: Optional[LoRARequest], - default_mm_loras: Optional[dict[str, str]], + lora_request: LoRARequest | None, + default_mm_loras: dict[str, str] | None, ): if ( not default_mm_loras or not isinstance(prompt, dict) - or "multi_modal_data" not in prompt + or not (mm_data := prompt.get("multi_modal_data") or {}) ): return lora_request - prompt = cast(Union[TextPrompt, TokensPrompt], prompt) - - intersection = set(prompt["multi_modal_data"].keys()).intersection( - default_mm_loras.keys() - ) + intersection = set( + mm_data.keys() # type: ignore + ).intersection(default_mm_loras.keys()) if not intersection: return lora_request if len(intersection) > 1: @@ -543,10 +531,10 @@ def _resolve_single_prompt_mm_lora( def collective_rpc( self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, + method: str | Callable[..., _R], + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> list[_R]: """ Execute an RPC call on all workers. @@ -588,9 +576,9 @@ def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: def _get_beam_search_lora_requests( self, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]], - prompts: list[Union[TokensPrompt, TextPrompt]], - ) -> list[Optional[LoRARequest]]: + lora_request: list[LoRARequest] | LoRARequest | None, + prompts: list[TokensPrompt | TextPrompt], + ) -> list[LoRARequest | None]: """Get the optional lora request corresponding to each prompt.""" if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts): raise ValueError( @@ -604,11 +592,11 @@ def _get_beam_search_lora_requests( def beam_search( self, - prompts: list[Union[TokensPrompt, TextPrompt]], + prompts: list[TokensPrompt | TextPrompt], params: BeamSearchParams, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + lora_request: list[LoRARequest] | LoRARequest | None = None, use_tqdm: bool = False, - concurrency_limit: Optional[int] = None, + concurrency_limit: int | None = None, ) -> list[BeamSearchOutput]: """ Generate sequences using beam search. @@ -787,16 +775,15 @@ def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt: def preprocess_chat( self, - messages: Union[ - list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]] - ], - chat_template: Optional[str] = None, + messages: list[ChatCompletionMessageParam] + | list[list[ChatCompletionMessageParam]], + chat_template: str | None = None, chat_template_content_format: ChatTemplateContentFormatOption = "auto", add_generation_prompt: bool = True, continue_final_message: bool = False, - tools: Optional[list[dict[str, Any]]] = None, - chat_template_kwargs: Optional[dict[str, Any]] = None, - mm_processor_kwargs: Optional[dict[str, Any]] = None, + tools: list[dict[str, Any]] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, + mm_processor_kwargs: dict[str, Any] | None = None, ) -> list[TokensPrompt]: """ Generate prompt for a chat conversation. 
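`LLM` now exposes `model_config`, `processor`, and `io_processor` as plain attributes and gains a `reset_mm_cache()` helper that clears both the input-processor cache and the engine-side multimodal cache. A brief usage sketch; the model name is a placeholder for any multimodal model:

    from vllm import LLM

    llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")

    # ... run some multimodal chat()/generate() calls ...

    # Drop cached multimodal processor outputs, e.g. between eval datasets.
    llm.reset_mm_cache()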
The pre-processed @@ -819,7 +806,7 @@ def preprocess_chat( list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] tokenizer = self.get_tokenizer() - model_config = self.llm_engine.get_model_config() + model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( chat_template, tools, @@ -885,19 +872,18 @@ def preprocess_chat( def chat( self, - messages: Union[ - list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]] - ], - sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[LoRARequest] = None, - chat_template: Optional[str] = None, + messages: list[ChatCompletionMessageParam] + | list[list[ChatCompletionMessageParam]], + sampling_params: SamplingParams | list[SamplingParams] | None = None, + use_tqdm: bool | Callable[..., tqdm] = True, + lora_request: LoRARequest | None = None, + chat_template: str | None = None, chat_template_content_format: ChatTemplateContentFormatOption = "auto", add_generation_prompt: bool = True, continue_final_message: bool = False, - tools: Optional[list[dict[str, Any]]] = None, - chat_template_kwargs: Optional[dict[str, Any]] = None, - mm_processor_kwargs: Optional[dict[str, Any]] = None, + tools: list[dict[str, Any]] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, + mm_processor_kwargs: dict[str, Any] | None = None, ) -> list[RequestOutput]: """ Generate responses for a chat conversation. @@ -970,14 +956,14 @@ def chat( def encode( self, - prompts: Union[PromptType, Sequence[PromptType], DataPrompt], - pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, + prompts: PromptType | Sequence[PromptType] | DataPrompt, + pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, *, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, + truncate_prompt_tokens: int | None = None, + use_tqdm: bool | Callable[..., tqdm] = True, + lora_request: list[LoRARequest] | LoRARequest | None = None, + pooling_task: PoolingTask | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input prompts. @@ -1011,27 +997,26 @@ def encode( instead pass them via the `inputs` parameter. 
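The `_resolve_single_prompt_mm_lora` tweak above means a prompt whose `multi_modal_data` is present but empty no longer triggers the modality-specific default LoRA lookup. A standalone sketch of the matching rule, assuming `default_mm_loras` maps modality names to adapter paths (the return value here is just the matched modality, not a full LoRARequest):

    def pick_default_mm_modality(
        prompt: dict, default_mm_loras: dict[str, str]
    ) -> str | None:
        # Only consult per-modality defaults when the prompt actually carries
        # multimodal data; empty dicts fall through.
        mm_data = prompt.get("multi_modal_data") or {}
        if not default_mm_loras or not mm_data:
            return None
        matches = set(mm_data) & set(default_mm_loras)
        if len(matches) == 1:
            return next(iter(matches))
        return None  # zero or ambiguous matches keep the caller's lora_request

    print(pick_default_mm_modality(
        {"prompt": "What is in the image?", "multi_modal_data": {"image": object()}},
        {"image": "/path/to/image-lora"},
    ))  # -> "image"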
""" - if self.supported_tasks == ["encode"] and pooling_task is None: - pooling_task = "encode" + error_str = ( + "pooling_task required for `LLM.encode`\n" + "Please use one of the more specific methods or set the " + "pooling_task when using `LLM.encode`:\n" + " - For embeddings, use `LLM.embed(...)` " + 'or `pooling_task="embed"`.\n' + " - For classification logits, use `LLM.classify(...)` " + 'or `pooling_task="classify"`.\n' + " - For similarity scores, use `LLM.score(...)`.\n" + " - For rewards, use `LLM.reward(...)` " + 'or `pooling_task="token_classify"`\n' + " - For token classification, " + 'use `pooling_task="token_classify"`\n' + ' - For multi-vector retrieval, use `pooling_task="token_embed"`' + ) if pooling_task is None: - pooling_task = "embed" if "embed" in self.supported_tasks else "encode" - - logger.warning_once( - "`LLM.encode` is currently using `pooling_task = %s`.\n" - "Please use one of the more specific methods or set the " - "task directly when using `LLM.encode`:\n" - " - For embeddings, use `LLM.embed(...)` " - 'or `pooling_task="embed"`.\n' - " - For classification logits, use `LLM.classify(...)` " - 'or `pooling_task="classify"`.\n' - " - For rewards, use `LLM.reward(...)` " - 'or `pooling_task="reward"`\n' - " - For similarity scores, use `LLM.score(...)`.", - pooling_task, - ) + raise ValueError(error_str) - model_config = self.llm_engine.model_config + model_config = self.model_config runner_type = model_config.runner_type if runner_type != "pooling": raise ValueError( @@ -1040,19 +1025,6 @@ def encode( "pooling model." ) - if pooling_task not in self.supported_tasks: - raise ValueError(f"pooling_task must be one of {self.supported_tasks}.") - - if pooling_params is None: - # Use default pooling params. - pooling_params = PoolingParams() - - for param in as_iter(pooling_params): - param.verify(pooling_task, model_config) - # for backwards compatibility - if truncate_prompt_tokens is not None: - param.truncate_prompt_tokens = truncate_prompt_tokens - io_processor_prompt = False if isinstance(prompts, dict) and "data" in prompts: io_processor_prompt = True @@ -1070,6 +1042,34 @@ def encode( # obtain the actual model prompts from the pre-processor prompts = self.io_processor.pre_process(prompt=validated_prompt) + if io_processor_prompt: + assert self.io_processor is not None + if is_list_of(pooling_params, PoolingParams): + validated_pooling_params: list[PoolingParams] = [] + for param in as_iter(pooling_params): + validated_pooling_params.append( + self.io_processor.validate_or_generate_params(param) + ) + pooling_params = validated_pooling_params + else: + assert not isinstance(pooling_params, Sequence) + pooling_params = self.io_processor.validate_or_generate_params( + pooling_params + ) + else: + if pooling_params is None: + # Use default pooling params. 
+ pooling_params = PoolingParams() + + if pooling_task not in self.supported_tasks: + raise ValueError(f"pooling_task must be one of {self.supported_tasks}.") + + for param in as_iter(pooling_params): + param.verify(pooling_task, model_config) + # for backwards compatibility + if truncate_prompt_tokens is not None: + param.truncate_prompt_tokens = truncate_prompt_tokens + self._validate_and_add_requests( prompts=prompts, params=pooling_params, @@ -1094,6 +1094,9 @@ def encode( PoolingRequestOutput[Any]( request_id="", outputs=processed_outputs, + num_cached_tokens=getattr( + processed_outputs, "num_cached_tokens", 0 + ), prompt_token_ids=[], finished=True, ) @@ -1103,12 +1106,12 @@ def encode( def embed( self, - prompts: Union[PromptType, Sequence[PromptType]], + prompts: PromptType | Sequence[PromptType], *, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + truncate_prompt_tokens: int | None = None, + use_tqdm: bool | Callable[..., tqdm] = True, + pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, + lora_request: list[LoRARequest] | LoRARequest | None = None, ) -> list[EmbeddingRequestOutput]: """ Generate an embedding vector for each prompt. @@ -1152,11 +1155,11 @@ def embed( def classify( self, - prompts: Union[PromptType, Sequence[PromptType]], + prompts: PromptType | Sequence[PromptType], *, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + use_tqdm: bool | Callable[..., tqdm] = True, + pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, + lora_request: list[LoRARequest] | LoRARequest | None = None, ) -> list[ClassificationRequestOutput]: """ Generate class logits for each prompt. @@ -1198,13 +1201,13 @@ def classify( def reward( self, - prompts: Union[PromptType, Sequence[PromptType]], + prompts: PromptType | Sequence[PromptType], /, *, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + truncate_prompt_tokens: int | None = None, + use_tqdm: bool | Callable[..., tqdm] = True, + pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, + lora_request: list[LoRARequest] | LoRARequest | None = None, ) -> list[PoolingRequestOutput]: """ Generate rewards for each prompt. 
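With the change above, `LLM.encode` no longer guesses a pooling task: callers must either use a task-specific helper or pass `pooling_task` explicitly, and rewards now map to `"token_classify"` instead of `"encode"`. A short usage sketch; the embedding model name is a placeholder:

    from vllm import LLM

    llm = LLM(model="BAAI/bge-small-en-v1.5", runner="pooling")

    # Preferred: the task-specific helper.
    outputs = llm.embed(["vLLM is a fast inference engine."])

    # Equivalent low-level call; omitting pooling_task now raises ValueError.
    outputs = llm.encode(
        ["vLLM is a fast inference engine."],
        pooling_task="embed",
    )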
@@ -1231,18 +1234,18 @@ def reward( lora_request=lora_request, pooling_params=pooling_params, truncate_prompt_tokens=truncate_prompt_tokens, - pooling_task="encode", + pooling_task="token_classify", ) def _embedding_score( self, tokenizer: AnyTokenizer, - text_1: list[Union[str, TextPrompt, TokensPrompt]], - text_2: list[Union[str, TextPrompt, TokensPrompt]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[PoolingParams] = None, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + text_1: list[str | TextPrompt | TokensPrompt], + text_2: list[str | TextPrompt | TokensPrompt], + truncate_prompt_tokens: int | None = None, + use_tqdm: bool | Callable[..., tqdm] = True, + pooling_params: PoolingParams | None = None, + lora_request: list[LoRARequest] | LoRARequest | None = None, ) -> list[ScoringRequestOutput]: encoded_output: list[PoolingRequestOutput] = self.encode( text_1 + text_2, @@ -1269,14 +1272,14 @@ def _embedding_score( def _cross_encoding_score( self, tokenizer: AnyTokenizer, - data_1: Union[list[str], list[ScoreContentPartParam]], - data_2: Union[list[str], list[ScoreContentPartParam]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[PoolingParams] = None, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + data_1: list[str] | list[ScoreContentPartParam], + data_2: list[str] | list[ScoreContentPartParam], + truncate_prompt_tokens: int | None = None, + use_tqdm: bool | Callable[..., tqdm] = True, + pooling_params: PoolingParams | None = None, + lora_request: list[LoRARequest] | LoRARequest | None = None, ) -> list[ScoringRequestOutput]: - model_config = self.llm_engine.model_config + model_config = self.model_config if isinstance(tokenizer, MistralTokenizer): raise ValueError("Score API is not supported for Mistral tokenizer") @@ -1287,7 +1290,6 @@ def _cross_encoding_score( if pooling_params is None: pooling_params = PoolingParams(task="score") - model_config = self.llm_engine.model_config pooling_params.verify("score", model_config) pooling_params_list = list[PoolingParams]() @@ -1301,8 +1303,6 @@ def _cross_encoding_score( input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - model_config = self.llm_engine.model_config - for q, d in input_pairs: _, engine_prompt = get_score_prompt( model_config=model_config, @@ -1336,14 +1336,14 @@ def _cross_encoding_score( def score( self, - data_1: Union[SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam], - data_2: Union[SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam], + data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam, + data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam, /, *, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[PoolingParams] = None, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + truncate_prompt_tokens: int | None = None, + use_tqdm: bool | Callable[..., tqdm] = True, + pooling_params: PoolingParams | None = None, + lora_request: list[LoRARequest] | LoRARequest | None = None, ) -> list[ScoringRequestOutput]: """Generate similarity scores for all pairs `<text,text_pair>` or `<multi-modal data, multi-modal data pair>`. 
@@ -1380,7 +1380,7 @@ def score( A list of `ScoringRequestOutput` objects containing the generated scores in the same order as the input prompts. """ - model_config = self.llm_engine.model_config + model_config = self.model_config runner_type = model_config.runner_type if runner_type != "pooling": raise ValueError( @@ -1411,9 +1411,9 @@ def score( if not model_config.is_multimodal_model: def check_data_type( - data: Union[ - SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam - ], + data: SingletonPrompt + | Sequence[SingletonPrompt] + | ScoreMultiModalParam, ): if isinstance(data, dict) and "content" in data: raise ValueError( @@ -1490,7 +1490,7 @@ def start_profile(self) -> None: def stop_profile(self) -> None: self.llm_engine.stop_profile() - def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + def reset_prefix_cache(self, device: Device | None = None) -> bool: return self.llm_engine.reset_prefix_cache(device) def sleep(self, level: int = 1): @@ -1515,7 +1515,7 @@ def sleep(self, level: int = 1): self.reset_prefix_cache() self.llm_engine.sleep(level=level) - def wake_up(self, tags: Optional[list[str]] = None): + def wake_up(self, tags: list[str] | None = None): """ Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep] method for more details. @@ -1533,7 +1533,7 @@ def get_metrics(self) -> list["Metric"]: """Return a snapshot of aggregated metrics from Prometheus. Returns: - A ``MetricSnapshot`` instance capturing the current state + A `MetricSnapshot` instance capturing the current state of all aggregated metrics from Prometheus. Note: @@ -1543,17 +1543,15 @@ def get_metrics(self) -> list["Metric"]: def _validate_and_add_requests( self, - prompts: Union[PromptType, Sequence[PromptType], DataPrompt], - params: Union[ - SamplingParams, - Sequence[SamplingParams], - PoolingParams, - Sequence[PoolingParams], - ], + prompts: PromptType | Sequence[PromptType] | DataPrompt, + params: SamplingParams + | Sequence[SamplingParams] + | PoolingParams + | Sequence[PoolingParams], *, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], - priority: Optional[list[int]] = None, + use_tqdm: bool | Callable[..., tqdm] = True, + lora_request: Sequence[LoRARequest] | LoRARequest | None, + priority: list[int] | None = None, ) -> None: if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. @@ -1595,8 +1593,8 @@ def _validate_and_add_requests( def _validate_mm_data_and_uuids( self, - multi_modal_data: Optional[Any], # MultiModalDataDict - multi_modal_uuids: Optional[Any], # MultiModalUUIDDict + multi_modal_data: Any | None, # MultiModalDataDict + multi_modal_uuids: Any | None, # MultiModalUUIDDict ): """ Validate that if any multi-modal data is skipped (i.e. 
None), @@ -1645,9 +1643,9 @@ def _process_inputs( self, request_id: str, engine_prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams | PoolingParams, *, - lora_request: Optional[LoRARequest], + lora_request: LoRARequest | None, priority: int, ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for LLMEngine.""" @@ -1658,8 +1656,7 @@ def _process_inputs( tokenization_kwargs, ) - processor = self._get_processor() - engine_request = processor.process_inputs( + engine_request = self.processor.process_inputs( request_id, engine_prompt, params, @@ -1672,8 +1669,8 @@ def _process_inputs( def _add_request( self, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[LoRARequest] = None, + params: SamplingParams | PoolingParams, + lora_request: LoRARequest | None = None, priority: int = 0, ) -> None: prompt_text, _, _ = get_prompt_components(prompt) @@ -1698,8 +1695,8 @@ def _add_request( ) def _run_engine( - self, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True - ) -> list[Union[RequestOutput, PoolingRequestOutput]]: + self, *, use_tqdm: bool | Callable[..., tqdm] = True + ) -> list[RequestOutput | PoolingRequestOutput]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -1712,7 +1709,7 @@ def _run_engine( ) # Run the engine. - outputs: list[Union[RequestOutput, PoolingRequestOutput]] = [] + outputs: list[RequestOutput | PoolingRequestOutput] = [] total_in_toks = 0 total_out_toks = 0 while self.llm_engine.has_unfinished_requests(): diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 96a84668e92b..678a7b3a60b5 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional, Union import torch @@ -15,17 +14,17 @@ class RequestLogger: - def __init__(self, *, max_log_len: Optional[int]) -> None: + def __init__(self, *, max_log_len: int | None) -> None: self.max_log_len = max_log_len def log_inputs( self, request_id: str, - prompt: Optional[str], - prompt_token_ids: Optional[list[int]], - prompt_embeds: Optional[torch.Tensor], - params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], - lora_request: Optional[LoRARequest], + prompt: str | None, + prompt_token_ids: list[int] | None, + prompt_embeds: torch.Tensor | None, + params: SamplingParams | PoolingParams | BeamSearchParams | None, + lora_request: LoRARequest | None, ) -> None: max_log_len = self.max_log_len if max_log_len is not None: @@ -35,16 +34,20 @@ def log_inputs( if prompt_token_ids is not None: prompt_token_ids = prompt_token_ids[:max_log_len] - logger.info( - "Received request %s: prompt: %r, " - "params: %s, prompt_token_ids: %s, " - "prompt_embeds shape: %s, " - "lora_request: %s.", + logger.debug( + "Request %s details: prompt: %r, " + "prompt_token_ids: %s, " + "prompt_embeds shape: %s.", request_id, prompt, - params, prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, + ) + + logger.info( + "Received request %s: params: %s, lora_request: %s.", + request_id, + params, lora_request, ) @@ -52,8 +55,8 @@ def log_outputs( self, request_id: str, outputs: str, - output_token_ids: Optional[Sequence[int]], - finish_reason: Optional[str] = None, + output_token_ids: Sequence[int] | None, + finish_reason: str | None = None, is_streaming: bool = False, delta: bool = 
False, ) -> None: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 889326dee749..29306e45bcf0 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -16,10 +16,10 @@ import tempfile import uuid from argparse import Namespace -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Annotated, Any, Callable, Literal, Optional +from typing import Annotated, Any, Literal import prometheus_client import pydantic @@ -41,11 +41,6 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ( - load_chat_template, - resolve_hf_chat_template, - resolve_mistral_chat_template, -) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args @@ -58,12 +53,14 @@ CompletionResponse, DetokenizeRequest, DetokenizeResponse, + EmbeddingBytesResponse, EmbeddingRequest, EmbeddingResponse, ErrorInfo, ErrorResponse, IOProcessorResponse, LoadLoRAAdapterRequest, + PoolingBytesResponse, PoolingRequest, PoolingResponse, RerankRequest, @@ -88,7 +85,6 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import ( BaseModelPath, - LoRAModulePath, OpenAIServingModels, ) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling @@ -105,19 +101,16 @@ cli_env_setup, load_aware_call, log_non_default_args, + process_chat_template, + process_lora_modules, with_cancellation, ) from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import ( - Device, - FlexibleArgumentParser, - decorate_logs, - is_valid_ipv6_address, - set_ulimit, -) +from vllm.utils import Device, FlexibleArgumentParser, set_ulimit +from vllm.utils.network_utils import is_valid_ipv6_address +from vllm.utils.system_utils import decorate_logs from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION @@ -166,8 +159,8 @@ async def build_async_engine_client( args: Namespace, *, usage_context: UsageContext = UsageContext.OPENAI_API_SERVER, - disable_frontend_multiprocessing: Optional[bool] = None, - client_config: Optional[dict[str, Any]] = None, + disable_frontend_multiprocessing: bool | None = None, + client_config: dict[str, Any] | None = None, ) -> AsyncIterator[EngineClient]: if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver": # The executor is expected to be mp. 
@@ -203,7 +196,7 @@ async def build_async_engine_client_from_engine_args( *, usage_context: UsageContext = UsageContext.OPENAI_API_SERVER, disable_frontend_multiprocessing: bool = False, - client_config: Optional[dict[str, Any]] = None, + client_config: dict[str, Any] | None = None, ) -> AsyncIterator[EngineClient]: """ Create EngineClient, either: @@ -227,7 +220,7 @@ async def build_async_engine_client_from_engine_args( from vllm.v1.engine.async_llm import AsyncLLM - async_llm: Optional[AsyncLLM] = None + async_llm: AsyncLLM | None = None # Don't mutate the input client_config client_config = dict(client_config) if client_config else {} @@ -239,6 +232,7 @@ async def build_async_engine_client_from_engine_args( vllm_config=vllm_config, usage_context=usage_context, enable_log_requests=engine_args.enable_log_requests, + aggregate_engine_logging=engine_args.aggregate_engine_logging, disable_log_stats=engine_args.disable_log_stats, client_addresses=client_config, client_count=client_count, @@ -308,35 +302,35 @@ def models(request: Request) -> OpenAIServingModels: return request.app.state.openai_serving_models -def responses(request: Request) -> Optional[OpenAIServingResponses]: +def responses(request: Request) -> OpenAIServingResponses | None: return request.app.state.openai_serving_responses -def chat(request: Request) -> Optional[OpenAIServingChat]: +def chat(request: Request) -> OpenAIServingChat | None: return request.app.state.openai_serving_chat -def completion(request: Request) -> Optional[OpenAIServingCompletion]: +def completion(request: Request) -> OpenAIServingCompletion | None: return request.app.state.openai_serving_completion -def pooling(request: Request) -> Optional[OpenAIServingPooling]: +def pooling(request: Request) -> OpenAIServingPooling | None: return request.app.state.openai_serving_pooling -def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: +def embedding(request: Request) -> OpenAIServingEmbedding | None: return request.app.state.openai_serving_embedding -def score(request: Request) -> Optional[ServingScores]: +def score(request: Request) -> ServingScores | None: return request.app.state.openai_serving_scores -def classify(request: Request) -> Optional[ServingClassification]: +def classify(request: Request) -> ServingClassification | None: return request.app.state.openai_serving_classification -def rerank(request: Request) -> Optional[ServingScores]: +def rerank(request: Request) -> ServingScores | None: return request.app.state.openai_serving_scores @@ -542,8 +536,8 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): async def retrieve_responses( response_id: str, raw_request: Request, - starting_after: Optional[int] = None, - stream: Optional[bool] = False, + starting_after: int | None = None, + stream: bool | None = False, ): handler = responses(raw_request) if handler is None: @@ -680,7 +674,10 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ) @with_cancellation @load_aware_call -async def create_embedding(request: EmbeddingRequest, raw_request: Request): +async def create_embedding( + request: EmbeddingRequest, + raw_request: Request, +): handler = embedding(raw_request) if handler is None: return base(raw_request).create_error_response( @@ -700,6 +697,12 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): ) elif isinstance(generator, EmbeddingResponse): return JSONResponse(content=generator.model_dump()) + elif isinstance(generator, EmbeddingBytesResponse): 
+ return StreamingResponse( + content=generator.body, + headers={"metadata": generator.metadata}, + media_type=generator.media_type, + ) assert_never(generator) @@ -732,6 +735,12 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): ) elif isinstance(generator, (PoolingResponse, IOProcessorResponse)): return JSONResponse(content=generator.model_dump()) + elif isinstance(generator, PoolingBytesResponse): + return StreamingResponse( + content=generator.body, + headers={"metadata": generator.metadata}, + media_type=generator.media_type, + ) assert_never(generator) @@ -993,6 +1002,16 @@ async def reset_prefix_cache(raw_request: Request): await engine_client(raw_request).reset_prefix_cache(device) return Response(status_code=200) + @router.post("/reset_mm_cache") + async def reset_mm_cache(raw_request: Request): + """ + Reset the multi-modal cache. Note that we currently do not check if the + multi-modal cache is successfully reset in the API server. + """ + logger.info("Resetting multi-modal cache...") + await engine_client(raw_request).reset_mm_cache() + return Response(status_code=200) + @router.post("/sleep") async def sleep(raw_request: Request): # get POST params @@ -1039,7 +1058,7 @@ async def collective_rpc(raw_request: Request): # User-defined `method` is responsible for deserialization if needed. args: list[str] = body.get("args", []) kwargs: dict[str, str] = body.get("kwargs", {}) - timeout: Optional[float] = body.get("timeout") + timeout: float | None = body.get("timeout") results = await engine_client(raw_request).collective_rpc( method=method, timeout=timeout, args=tuple(args), kwargs=kwargs ) @@ -1120,7 +1139,7 @@ async def is_scaling_elastic_ep(raw_request: Request): # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) RequestType = Any -GetHandlerFn = Callable[[Request], Optional[OpenAIServing]] +GetHandlerFn = Callable[[Request], OpenAIServing | None] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] # NOTE: Items defined earlier take higher priority @@ -1236,7 +1255,7 @@ async def unload_lora_adapter( return Response(status_code=200, content=response) -def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: +def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None try: @@ -1601,10 +1620,11 @@ async def log_response(request: Request, call_next): async def init_app_state( engine_client: EngineClient, - vllm_config: VllmConfig, state: State, args: Namespace, ) -> None: + vllm_config = engine_client.vllm_config + if args.served_model_name is not None: served_model_names = args.served_model_name else: @@ -1622,41 +1642,16 @@ async def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.vllm_config = vllm_config - model_config = vllm_config.model_config supported_tasks = await engine_client.get_supported_tasks() + logger.info("Supported tasks: %s", supported_tasks) - logger.info("Supported_tasks: %s", supported_tasks) - - resolved_chat_template = load_chat_template(args.chat_template) - if resolved_chat_template is not None: - # Get the tokenizer to check official template - tokenizer = await engine_client.get_tokenizer() - - if isinstance(tokenizer, MistralTokenizer): - # The warning is logged in resolve_mistral_chat_template. 
- resolved_chat_template = resolve_mistral_chat_template( - chat_template=resolved_chat_template - ) - else: - hf_chat_template = resolve_hf_chat_template( - tokenizer=tokenizer, - chat_template=None, - tools=None, - model_config=vllm_config.model_config, - ) - - if hf_chat_template != resolved_chat_template: - logger.warning( - "Using supplied chat template: %s\n" - "It is different from official chat template '%s'. " - "This discrepancy may lead to performance degradation.", - resolved_chat_template, - args.model, - ) + resolved_chat_template = await process_chat_template( + args.chat_template, engine_client, vllm_config.model_config + ) if args.tool_server == "demo": - tool_server: Optional[ToolServer] = DemoToolServer() + tool_server: ToolServer | None = DemoToolServer() assert isinstance(tool_server, DemoToolServer) await tool_server.init_and_validate() elif args.tool_server: @@ -1672,23 +1667,15 @@ async def init_app_state( else {} ) - lora_modules = args.lora_modules - if default_mm_loras: - default_mm_lora_paths = [ - LoRAModulePath( - name=modality, - path=lora_path, - ) - for modality, lora_path in default_mm_loras.items() - ] - if args.lora_modules is None: - lora_modules = default_mm_lora_paths - else: - lora_modules += default_mm_lora_paths + default_mm_loras = ( + vllm_config.lora_config.default_mm_loras + if vllm_config.lora_config is not None + else {} + ) + lora_modules = process_lora_modules(args.lora_modules, default_mm_loras) state.openai_serving_models = OpenAIServingModels( engine_client=engine_client, - model_config=model_config, base_model_paths=base_model_paths, lora_modules=lora_modules, ) @@ -1696,7 +1683,6 @@ async def init_app_state( state.openai_serving_responses = ( OpenAIServingResponses( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, @@ -1717,7 +1703,6 @@ async def init_app_state( state.openai_serving_chat = ( OpenAIServingChat( engine_client, - model_config, state.openai_serving_models, args.response_role, request_logger=request_logger, @@ -1740,7 +1725,6 @@ async def init_app_state( state.openai_serving_completion = ( OpenAIServingCompletion( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, @@ -1752,23 +1736,29 @@ async def init_app_state( else None ) state.openai_serving_pooling = ( - OpenAIServingPooling( - engine_client, - vllm_config, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - trust_request_chat_template=args.trust_request_chat_template, - log_error_stack=args.log_error_stack, + ( + OpenAIServingPooling( + engine_client, + state.openai_serving_models, + supported_tasks=supported_tasks, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + ) + if ( + any( + task in supported_tasks + for task in ["token_embed", "token_classify", "plugin"] + ) ) - if "encode" in supported_tasks else None ) state.openai_serving_embedding = ( OpenAIServingEmbedding( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, @@ -1782,7 +1772,6 @@ async def init_app_state( state.openai_serving_classification 
= ( ServingClassification( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, @@ -1793,7 +1782,6 @@ async def init_app_state( state.openai_serving_scores = ( ServingScores( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, @@ -1803,7 +1791,6 @@ async def init_app_state( ) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, @@ -1814,10 +1801,10 @@ async def init_app_state( state.openai_serving_transcription = ( OpenAIServingTranscription( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, ) if "transcription" in supported_tasks else None @@ -1825,10 +1812,10 @@ async def init_app_state( state.openai_serving_translation = ( OpenAIServingTranslation( engine_client, - model_config, state.openai_serving_models, request_logger=request_logger, log_error_stack=args.log_error_stack, + enable_force_include_usage=args.enable_force_include_usage, ) if "transcription" in supported_tasks else None @@ -1946,12 +1933,11 @@ async def run_server_worker( maybe_register_tokenizer_info_endpoint(args) app = build_app(args) - vllm_config = await engine_client.get_vllm_config() - await init_app_state(engine_client, vllm_config, app.state, args) + await init_app_state(engine_client, app.state, args) logger.info( "Starting vLLM API server %d on %s", - vllm_config.parallel_config._api_process_rank, + engine_client.vllm_config.parallel_config._api_process_rank, listen_address, ) shutdown_task = await serve_http( diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 1f16646db63b..99d6cbaa86b8 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -11,7 +11,7 @@ import ssl from collections.abc import Sequence from dataclasses import field -from typing import Literal, Optional, Union +from typing import Literal from pydantic.dataclasses import dataclass @@ -39,8 +39,8 @@ def __call__( self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, - values: Optional[Union[str, Sequence[str]]], - option_string: Optional[str] = None, + values: str | Sequence[str] | None, + option_string: str | None = None, ): if values is None: values = [] @@ -73,11 +73,11 @@ def __call__( class FrontendArgs: """Arguments for the OpenAI-compatible frontend server.""" - host: Optional[str] = None + host: str | None = None """Host name.""" port: int = 8000 """Port number.""" - uds: Optional[str] = None + uds: str | None = None """Unix domain socket path. If set, host and port arguments are ignored.""" uvicorn_log_level: Literal[ "debug", "info", "warning", "error", "critical", "trace" @@ -93,15 +93,15 @@ class FrontendArgs: """Allowed methods.""" allowed_headers: list[str] = field(default_factory=lambda: ["*"]) """Allowed headers.""" - api_key: Optional[list[str]] = None + api_key: list[str] | None = None """If provided, the server will require one of these keys to be presented in the header.""" - lora_modules: Optional[list[LoRAModulePath]] = None + lora_modules: list[LoRAModulePath] | None = None """LoRA modules configurations in either 'name=path' format or JSON format or JSON list format. 
Example (old format): `'name=path'` Example (new format): `{\"name\": \"name\", \"path\": \"lora_path\", \"base_model_name\": \"id\"}`""" - chat_template: Optional[str] = None + chat_template: str | None = None """The file path to the chat template, or the template in single-line form for the specified model.""" chat_template_content_format: ChatTemplateContentFormatOption = "auto" @@ -116,17 +116,17 @@ class FrontendArgs: or the ones from tokenizer.""" response_role: str = "assistant" """The role name to return if `request.add_generation_prompt=true`.""" - ssl_keyfile: Optional[str] = None + ssl_keyfile: str | None = None """The file path to the SSL key file.""" - ssl_certfile: Optional[str] = None + ssl_certfile: str | None = None """The file path to the SSL cert file.""" - ssl_ca_certs: Optional[str] = None + ssl_ca_certs: str | None = None """The CA certificates file.""" enable_ssl_refresh: bool = False """Refresh SSL Context when SSL certificate files change""" ssl_cert_reqs: int = int(ssl.CERT_NONE) """Whether client certificate is required (see stdlib ssl module's).""" - root_path: Optional[str] = None + root_path: str | None = None """FastAPI root_path when app is behind a path based routing proxy.""" middleware: list[str] = field(default_factory=lambda: []) """Additional ASGI middleware to apply to the app. We accept multiple @@ -149,7 +149,7 @@ class FrontendArgs: exclude_tools_when_tool_choice_none: bool = False """If specified, exclude tool definitions in prompts when tool_choice='none'.""" - tool_call_parser: Optional[str] = None + tool_call_parser: str | None = None """Select the tool call parser depending on the model that you're using. This is used to parse the model-generated tool call into OpenAI API format. Required for `--enable-auto-tool-choice`. You can choose any option from @@ -158,13 +158,13 @@ class FrontendArgs: """Special the tool parser plugin write to parse the model-generated tool into OpenAI API format, the name register in this plugin can be used in `--tool-call-parser`.""" - tool_server: Optional[str] = None + tool_server: str | None = None """Comma-separated list of host:port pairs (IPv4, IPv6, or hostname). Examples: 127.0.0.1:8000, [::1]:8000, localhost:1234. Or `demo` for demo purpose.""" - log_config_file: Optional[str] = envs.VLLM_LOGGING_CONFIG_PATH + log_config_file: str | None = envs.VLLM_LOGGING_CONFIG_PATH """Path to logging config JSON file for both vllm and uvicorn""" - max_log_len: Optional[int] = None + max_log_len: int | None = None """Max number of prompt characters or prompt ID numbers being printed in log. 
The default of None means unlimited.""" disable_fastapi_docs: bool = False diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 2ea9fbf386ba..dedbc23ec83f 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -3,7 +3,6 @@ from collections.abc import Iterable from functools import lru_cache, partial -from typing import Optional, Union import torch @@ -16,8 +15,8 @@ class AllowedTokenIdsLogitsProcessor: specific set of token ids.""" def __init__(self, allowed_ids: Iterable[int]): - self.allowed_ids: Optional[list[int]] = list(allowed_ids) - self.mask: Optional[torch.Tensor] = None + self.allowed_ids: list[int] | None = list(allowed_ids) + self.mask: torch.Tensor | None = None def __call__(self, token_ids: list[int], logits: torch.Tensor) -> torch.Tensor: if self.mask is None: @@ -53,8 +52,8 @@ def logit_bias_logits_processor( def get_logits_processors( - logit_bias: Optional[Union[dict[int, float], dict[str, float]]], - allowed_token_ids: Optional[list[int]], + logit_bias: dict[int, float] | dict[str, float] | None, + allowed_token_ids: list[int] | None, tokenizer: AnyTokenizer, ) -> list[LogitsProcessor]: logits_processors: list[LogitsProcessor] = [] diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6ff7ceef4805..9782641296d6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,7 +6,7 @@ import json import time from http import HTTPStatus -from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, Union +from typing import Annotated, Any, ClassVar, Generic, Literal, TypeAlias, TypeVar import regex as re import torch @@ -48,12 +48,19 @@ Content as ResponseReasoningTextContent, ) +from vllm.utils.serial_utils import ( + EmbedDType, + EncodingFormat, + Endianness, +) + # Backward compatibility for OpenAI client versions try: # For older openai versions (< 1.100.0) from openai.types.responses import ResponseTextConfig except ImportError: # For newer openai versions (>= 1.100.0) from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig + from openai.types.responses.response import IncompleteDetails, ToolChoice from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning @@ -62,12 +69,12 @@ ConfigDict, Field, TypeAdapter, + ValidationError, ValidationInfo, field_serializer, field_validator, model_validator, ) -from typing_extensions import TypeAlias from vllm import envs from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id @@ -81,7 +88,8 @@ SamplingParams, StructuredOutputsParams, ) -from vllm.utils import random_uuid, resolve_obj_by_qualname +from vllm.utils import random_uuid +from vllm.utils.import_utils import resolve_obj_by_qualname logger = init_logger(__name__) @@ -93,7 +101,7 @@ class OpenAIBaseModel(BaseModel): model_config = ConfigDict(extra="allow") # Cache class field names - field_names: ClassVar[Optional[set[str]]] = None + field_names: ClassVar[set[str] | None] = None @model_validator(mode="wrap") @classmethod @@ -123,7 +131,7 @@ def __log_extra_fields__(cls, data, handler): class ErrorInfo(OpenAIBaseModel): message: str type: str - param: Optional[str] = None + param: str | None = None code: int @@ -142,7 +150,7 @@ class ModelPermission(OpenAIBaseModel): allow_view: bool = True allow_fine_tuning: bool = False organization: str = "*" - group: Optional[str] = None + 
group: str | None = None is_blocking: bool = False @@ -151,9 +159,9 @@ class ModelCard(OpenAIBaseModel): object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) owned_by: str = "vllm" - root: Optional[str] = None - parent: Optional[str] = None - max_model_len: Optional[int] = None + root: str | None = None + parent: str | None = None + max_model_len: int | None = None permission: list[ModelPermission] = Field(default_factory=list) @@ -163,64 +171,74 @@ class ModelList(OpenAIBaseModel): class PromptTokenUsageInfo(OpenAIBaseModel): - cached_tokens: Optional[int] = None + cached_tokens: int | None = None class UsageInfo(OpenAIBaseModel): prompt_tokens: int = 0 total_tokens: int = 0 - completion_tokens: Optional[int] = 0 - prompt_tokens_details: Optional[PromptTokenUsageInfo] = None + completion_tokens: int | None = 0 + prompt_tokens_details: PromptTokenUsageInfo | None = None class RequestResponseMetadata(BaseModel): request_id: str - final_usage_info: Optional[UsageInfo] = None + final_usage_info: UsageInfo | None = None class JsonSchemaResponseFormat(OpenAIBaseModel): name: str - description: Optional[str] = None + description: str | None = None # schema is the field in openai but that causes conflicts with pydantic so # instead use json_schema with an alias - json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema") - strict: Optional[bool] = None + json_schema: dict[str, Any] | None = Field(default=None, alias="schema") + strict: bool | None = None -class StructuralTag(OpenAIBaseModel): +class LegacyStructuralTag(OpenAIBaseModel): begin: str # schema is the field, but that causes conflicts with pydantic so # instead use structural_tag_schema with an alias - structural_tag_schema: Optional[dict[str, Any]] = Field( - default=None, alias="schema" - ) + structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema") end: str -class StructuralTagResponseFormat(OpenAIBaseModel): +class LegacyStructuralTagResponseFormat(OpenAIBaseModel): type: Literal["structural_tag"] - structures: list[StructuralTag] + structures: list[LegacyStructuralTag] triggers: list[str] +class StructuralTagResponseFormat(OpenAIBaseModel): + type: Literal["structural_tag"] + format: Any + + +AnyStructuralTagResponseFormat: TypeAlias = ( + LegacyStructuralTagResponseFormat | StructuralTagResponseFormat +) + + class ResponseFormat(OpenAIBaseModel): # type must be "json_schema", "json_object", or "text" type: Literal["text", "json_object", "json_schema"] - json_schema: Optional[JsonSchemaResponseFormat] = None + json_schema: JsonSchemaResponseFormat | None = None -AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat] +AnyResponseFormat: TypeAlias = ( + ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat +) class StreamOptions(OpenAIBaseModel): - include_usage: Optional[bool] = True - continuous_usage_stats: Optional[bool] = False + include_usage: bool | None = True + continuous_usage_stats: bool | None = False class FunctionDefinition(OpenAIBaseModel): name: str - description: Optional[str] = None - parameters: Optional[dict[str, Any]] = None + description: str | None = None + parameters: dict[str, Any] | None = None class ChatCompletionToolsParam(OpenAIBaseModel): @@ -241,18 +259,18 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): # see https://github.com/pydantic/pydantic/issues/3125 class LogitsProcessorConstructor(BaseModel): qualname: str - args: Optional[list[Any]] = None - kwargs: 
Optional[dict[str, Any]] = None + args: list[Any] | None = None + kwargs: dict[str, Any] | None = None model_config = ConfigDict(extra="forbid") -LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] +LogitsProcessors = list[str | LogitsProcessorConstructor] def get_logits_processors( - processors: Optional[LogitsProcessors], pattern: Optional[str] -) -> Optional[list[Any]]: + processors: LogitsProcessors | None, pattern: str | None +) -> list[Any] | None: if processors and pattern: logits_processors = [] for processor in processors: @@ -284,16 +302,16 @@ def get_logits_processors( return None -ResponseInputOutputItem: TypeAlias = Union[ - ResponseInputItemParam, ResponseReasoningItem, ResponseFunctionToolCall -] +ResponseInputOutputItem: TypeAlias = ( + ResponseInputItemParam | ResponseReasoningItem | ResponseFunctionToolCall +) class ResponsesRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/responses/create - background: Optional[bool] = False - include: Optional[ + background: bool | None = False + include: ( list[ Literal[ "code_interpreter_call.outputs", @@ -304,28 +322,29 @@ class ResponsesRequest(OpenAIBaseModel): "reasoning.encrypted_content", ], ] - ] = None - input: Union[str, list[ResponseInputOutputItem]] - instructions: Optional[str] = None - max_output_tokens: Optional[int] = None - max_tool_calls: Optional[int] = None - metadata: Optional[Metadata] = None - model: Optional[str] = None - parallel_tool_calls: Optional[bool] = True - previous_response_id: Optional[str] = None - prompt: Optional[ResponsePrompt] = None - reasoning: Optional[Reasoning] = None + | None + ) = None + input: str | list[ResponseInputOutputItem] + instructions: str | None = None + max_output_tokens: int | None = None + max_tool_calls: int | None = None + metadata: Metadata | None = None + model: str | None = None + parallel_tool_calls: bool | None = True + previous_response_id: str | None = None + prompt: ResponsePrompt | None = None + reasoning: Reasoning | None = None service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto" - store: Optional[bool] = True - stream: Optional[bool] = False - temperature: Optional[float] = None - text: Optional[ResponseTextConfig] = None + store: bool | None = True + stream: bool | None = False + temperature: float | None = None + text: ResponseTextConfig | None = None tool_choice: ToolChoice = "auto" tools: list[Tool] = Field(default_factory=list) - top_logprobs: Optional[int] = 0 - top_p: Optional[float] = None - truncation: Optional[Literal["auto", "disabled"]] = "disabled" - user: Optional[str] = None + top_logprobs: int | None = 0 + top_p: float | None = None + truncation: Literal["auto", "disabled"] | None = "disabled" + user: str | None = None # --8<-- [start:responses-extra-params] request_id: str = Field( @@ -336,7 +355,7 @@ class ResponsesRequest(OpenAIBaseModel): "through out the inference process and return in response." ), ) - mm_processor_kwargs: Optional[dict[str, Any]] = Field( + mm_processor_kwargs: dict[str, Any] | None = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) @@ -348,7 +367,7 @@ class ResponsesRequest(OpenAIBaseModel): "if the served model does not use priority scheduling." 
), ) - cache_salt: Optional[str] = Field( + cache_salt: str | None = Field( default=None, description=( "If specified, the prefix cache will be salted with the provided " @@ -378,7 +397,7 @@ class ResponsesRequest(OpenAIBaseModel): def to_sampling_params( self, default_max_tokens: int, - default_sampling_params: Optional[dict] = None, + default_sampling_params: dict | None = None, ) -> SamplingParams: if self.max_output_tokens is None: max_tokens = default_max_tokens @@ -460,63 +479,104 @@ def check_cache_salt_support(cls, data): ) return data + @model_validator(mode="before") + def function_call_parsing(cls, data): + """Parse function_call dictionaries into ResponseFunctionToolCall objects. + This ensures Pydantic can properly resolve union types in the input field. + Function calls provided as dicts are converted to ResponseFunctionToolCall + objects before validation, while invalid structures are left for Pydantic + to reject with appropriate error messages. + """ + + input_data = data.get("input") + + # Early return for None, strings, or bytes + # (strings are iterable but shouldn't be processed) + if input_data is None or isinstance(input_data, (str, bytes)): + return data + + # Convert iterators (like ValidatorIterator) to list + if not isinstance(input_data, list): + try: + input_data = list(input_data) + except TypeError: + # Not iterable, leave as-is for Pydantic to handle + return data + + processed_input = [] + for item in input_data: + if isinstance(item, dict) and item.get("type") == "function_call": + try: + processed_input.append(ResponseFunctionToolCall(**item)) + except ValidationError: + # Let Pydantic handle validation for malformed function calls + logger.debug( + "Failed to parse function_call to ResponseFunctionToolCall, " + "leaving for Pydantic validation" + ) + processed_input.append(item) + else: + processed_input.append(item) + + data["input"] = processed_input + return data + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create messages: list[ChatCompletionMessageParam] - model: Optional[str] = None - frequency_penalty: Optional[float] = 0.0 - logit_bias: Optional[dict[str, float]] = None - logprobs: Optional[bool] = False - top_logprobs: Optional[int] = 0 - max_tokens: Optional[int] = Field( + model: str | None = None + frequency_penalty: float | None = 0.0 + logit_bias: dict[str, float] | None = None + logprobs: bool | None = False + top_logprobs: int | None = 0 + max_tokens: int | None = Field( default=None, deprecated="max_tokens is deprecated in favor of " "the max_completion_tokens field", ) - max_completion_tokens: Optional[int] = None - n: Optional[int] = 1 - presence_penalty: Optional[float] = 0.0 - response_format: Optional[AnyResponseFormat] = None - seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, list[str]]] = [] - stream: Optional[bool] = False - stream_options: Optional[StreamOptions] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - tools: Optional[list[ChatCompletionToolsParam]] = None - tool_choice: Optional[ - Union[ - Literal["none"], - Literal["auto"], - Literal["required"], - ChatCompletionNamedToolChoiceParam, - ] - ] = "none" - reasoning_effort: Optional[Literal["low", "medium", "high"]] = None + max_completion_tokens: int | None = None + n: int | None = 1 + presence_penalty: float | None = 0.0 + response_format: AnyResponseFormat | None = None + seed: int | 
None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: str | list[str] | None = [] + stream: bool | None = False + stream_options: StreamOptions | None = None + temperature: float | None = None + top_p: float | None = None + tools: list[ChatCompletionToolsParam] | None = None + tool_choice: ( + Literal["none"] + | Literal["auto"] + | Literal["required"] + | ChatCompletionNamedToolChoiceParam + | None + ) = "none" + reasoning_effort: Literal["low", "medium", "high"] | None = None include_reasoning: bool = True # NOTE this will be ignored by vLLM -- the model determines the behavior - parallel_tool_calls: Optional[bool] = False - user: Optional[str] = None + parallel_tool_calls: bool | None = False + user: str | None = None # --8<-- [start:chat-completion-sampling-params] - best_of: Optional[int] = None + best_of: int | None = None use_beam_search: bool = False - top_k: Optional[int] = None - min_p: Optional[float] = None - repetition_penalty: Optional[float] = None + top_k: int | None = None + min_p: float | None = None + repetition_penalty: float | None = None length_penalty: float = 1.0 - stop_token_ids: Optional[list[int]] = [] + stop_token_ids: list[int] | None = [] include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - prompt_logprobs: Optional[int] = None - allowed_token_ids: Optional[list[int]] = None + truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None + prompt_logprobs: int | None = None + allowed_token_ids: list[int] | None = None bad_words: list[str] = Field(default_factory=list) # --8<-- [end:chat-completion-sampling-params] @@ -556,7 +616,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "default)." ), ) - documents: Optional[list[dict[str, str]]] = Field( + documents: list[dict[str, str]] | None = Field( default=None, description=( "A list of dicts representing documents that will be accessible to " @@ -566,7 +626,7 @@ class ChatCompletionRequest(OpenAIBaseModel): '"title" and "text" keys.' ), ) - chat_template: Optional[str] = Field( + chat_template: str | None = Field( default=None, description=( "A Jinja template to use for this conversion. " @@ -575,22 +635,22 @@ class ChatCompletionRequest(OpenAIBaseModel): "does not define one." ), ) - chat_template_kwargs: Optional[dict[str, Any]] = Field( + chat_template_kwargs: dict[str, Any] | None = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " "Will be accessible by the chat template." ), ) - mm_processor_kwargs: Optional[dict[str, Any]] = Field( + mm_processor_kwargs: dict[str, Any] | None = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) - structured_outputs: Optional[StructuredOutputsParams] = Field( + structured_outputs: StructuredOutputsParams | None = Field( default=None, description="Additional kwargs for structured outputs", ) - guided_json: Optional[Union[str, dict, BaseModel]] = Field( + guided_json: str | dict | BaseModel | None = Field( default=None, description=( "`guided_json` is deprecated. " @@ -598,7 +658,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "Please pass `json` to `structured_outputs` instead." ), ) - guided_regex: Optional[str] = Field( + guided_regex: str | None = Field( default=None, description=( "`guided_regex` is deprecated. 
" @@ -606,7 +666,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "Please pass `regex` to `structured_outputs` instead." ), ) - guided_choice: Optional[list[str]] = Field( + guided_choice: list[str] | None = Field( default=None, description=( "`guided_choice` is deprecated. " @@ -614,7 +674,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "Please pass `choice` to `structured_outputs` instead." ), ) - guided_grammar: Optional[str] = Field( + guided_grammar: str | None = Field( default=None, description=( "`guided_grammar` is deprecated. " @@ -622,7 +682,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "Please pass `grammar` to `structured_outputs` instead." ), ) - structural_tag: Optional[str] = Field( + structural_tag: str | None = Field( default=None, description=( "`structural_tag` is deprecated. " @@ -630,7 +690,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "Please pass `structural_tag` to `structured_outputs` instead." ), ) - guided_decoding_backend: Optional[str] = Field( + guided_decoding_backend: str | None = Field( default=None, description=( "`guided_decoding_backend` is deprecated. " @@ -638,7 +698,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "Please remove it from your request." ), ) - guided_whitespace_pattern: Optional[str] = Field( + guided_whitespace_pattern: str | None = Field( default=None, description=( "`guided_whitespace_pattern` is deprecated. " @@ -662,7 +722,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "through out the inference process and return in response." ), ) - logits_processors: Optional[LogitsProcessors] = Field( + logits_processors: LogitsProcessors | None = Field( default=None, description=( "A list of either qualified names of logits processors, or " @@ -675,7 +735,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "{'param': 'value'}}." ), ) - return_tokens_as_token_ids: Optional[bool] = Field( + return_tokens_as_token_ids: bool | None = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " @@ -683,7 +743,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "that are not JSON-encodable can be identified." ), ) - return_token_ids: Optional[bool] = Field( + return_token_ids: bool | None = Field( default=None, description=( "If specified, the result will include token IDs alongside the " @@ -693,7 +753,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "need to map generated text back to input tokens." ), ) - cache_salt: Optional[str] = Field( + cache_salt: str | None = Field( default=None, description=( "If specified, the prefix cache will be salted with the provided " @@ -704,12 +764,12 @@ class ChatCompletionRequest(OpenAIBaseModel): "to 256 bit). Not supported by vLLM engine V0." 
), ) - kv_transfer_params: Optional[dict[str, Any]] = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) - vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( + vllm_xargs: dict[str, str | int | float] | None = Field( default=None, description=( "Additional request parameters with string or " @@ -749,7 +809,7 @@ def to_beam_search_params( def to_sampling_params( self, max_tokens: int, - logits_processor_pattern: Optional[str], + logits_processor_pattern: str | None, default_sampling_params: dict, ) -> SamplingParams: # Default parameters @@ -794,8 +854,7 @@ def to_sampling_params( self.structured_outputs = StructuredOutputsParams(**kwargs) response_format = self.response_format - json_schema_from_tool = self._get_json_schema_from_tool() - if response_format is not None or json_schema_from_tool is not None: + if response_format is not None: # If structured outputs wasn't already enabled, # we must enable it for these features to work if self.structured_outputs is None: @@ -812,15 +871,15 @@ def to_sampling_params( elif response_format.type == "structural_tag": structural_tag = response_format assert structural_tag is not None and isinstance( - structural_tag, StructuralTagResponseFormat + structural_tag, + ( + LegacyStructuralTagResponseFormat, + StructuralTagResponseFormat, + ), ) s_tag_obj = structural_tag.model_dump(by_alias=True) self.structured_outputs.structural_tag = json.dumps(s_tag_obj) - # Set structured output params for tool calling - if json_schema_from_tool is not None: - self.structured_outputs.json = json_schema_from_tool - extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} if self.kv_transfer_params: # Pass in kv_transfer_params via extra_args @@ -860,72 +919,6 @@ def to_sampling_params( extra_args=extra_args or None, ) - def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]: - # user has chosen to not use any tool - if self.tool_choice == "none" or self.tools is None: - return None - - # user has chosen to use a named tool - if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: - tool_name = self.tool_choice.function.name - tools = {tool.function.name: tool.function for tool in self.tools} - if tool_name not in tools: - raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") - tool = tools[tool_name] - return tool.parameters - - if self.tool_choice == "required": - # Pydantic schema generation cannot be used since the JSON schema - # has to be constructed for a specific instantiation of a tool list - # so that parameters of a function are correctly generated - # based on the chosen function name - def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: - return { - "properties": { - "name": {"type": "string", "enum": [tool.function.name]}, - # parameters are always generated as '{}' in the final - # output if they are missing from the request - # (i.e. 
are None or '{}') so the schema is - # updated to produce an empty object in that case - "parameters": tool.function.parameters - if tool.function.parameters - else {"type": "object", "properties": {}}, - }, - "required": ["name", "parameters"], - } - - def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict: - all_defs = dict[str, dict[str, Any]]() - for tool in tools: - if tool.function.parameters is None: - continue - defs = tool.function.parameters.pop("$defs", {}) - for def_name, def_schema in defs.items(): - if def_name in all_defs and all_defs[def_name] != def_schema: - raise ValueError( - f"Tool definition '{def_name}' has " - "multiple schemas, which is not " - "supported." - ) - else: - all_defs[def_name] = def_schema - return all_defs - - json_schema = { - "type": "array", - "minItems": 1, - "items": { - "type": "object", - "anyOf": [get_tool_schema(tool) for tool in self.tools], - }, - } - json_schema_defs = get_tool_schema_defs(self.tools) - if json_schema_defs: - json_schema["$defs"] = json_schema_defs - return json_schema - - return None - @model_validator(mode="before") @classmethod def validate_stream_options(cls, data): @@ -1098,44 +1091,44 @@ def check_cache_salt_support(cls, data): class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create - model: Optional[str] = None - prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None - best_of: Optional[int] = None - echo: Optional[bool] = False - frequency_penalty: Optional[float] = 0.0 - logit_bias: Optional[dict[str, float]] = None - logprobs: Optional[int] = None - max_tokens: Optional[int] = 16 + model: str | None = None + prompt: list[int] | list[list[int]] | str | list[str] | None = None + best_of: int | None = None + echo: bool | None = False + frequency_penalty: float | None = 0.0 + logit_bias: dict[str, float] | None = None + logprobs: int | None = None + max_tokens: int | None = 16 n: int = 1 - presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: Optional[Union[str, list[str]]] = [] - stream: Optional[bool] = False - stream_options: Optional[StreamOptions] = None - suffix: Optional[str] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - user: Optional[str] = None + presence_penalty: float | None = 0.0 + seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: str | list[str] | None = [] + stream: bool | None = False + stream_options: StreamOptions | None = None + suffix: str | None = None + temperature: float | None = None + top_p: float | None = None + user: str | None = None # --8<-- [start:completion-sampling-params] use_beam_search: bool = False - top_k: Optional[int] = None - min_p: Optional[float] = None - repetition_penalty: Optional[float] = None + top_k: int | None = None + min_p: float | None = None + repetition_penalty: float | None = None length_penalty: float = 1.0 - stop_token_ids: Optional[list[int]] = [] + stop_token_ids: list[int] | None = [] include_stop_str_in_output: bool = False ignore_eos: bool = False min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - allowed_token_ids: Optional[list[int]] = None - prompt_logprobs: Optional[int] = None + truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None + allowed_token_ids: 
list[int] | None = None + prompt_logprobs: int | None = None # --8<-- [end:completion-sampling-params] # --8<-- [start:completion-extra-params] - prompt_embeds: Optional[Union[bytes, list[bytes]]] = None + prompt_embeds: bytes | list[bytes] | None = None add_special_tokens: bool = Field( default=True, description=( @@ -1143,7 +1136,7 @@ class CompletionRequest(OpenAIBaseModel): "the prompt." ), ) - response_format: Optional[AnyResponseFormat] = Field( + response_format: AnyResponseFormat | None = Field( default=None, description=( "Similar to chat completion, this parameter specifies the format " @@ -1151,11 +1144,11 @@ class CompletionRequest(OpenAIBaseModel): ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." ), ) - structured_outputs: Optional[StructuredOutputsParams] = Field( + structured_outputs: StructuredOutputsParams | None = Field( default=None, description="Additional kwargs for structured outputs", ) - guided_json: Optional[Union[str, dict, BaseModel]] = Field( + guided_json: str | dict | BaseModel | None = Field( default=None, description=( "`guided_json` is deprecated. " @@ -1163,7 +1156,7 @@ class CompletionRequest(OpenAIBaseModel): "Please pass `json` to `structured_outputs` instead." ), ) - guided_regex: Optional[str] = Field( + guided_regex: str | None = Field( default=None, description=( "`guided_regex` is deprecated. " @@ -1171,7 +1164,7 @@ class CompletionRequest(OpenAIBaseModel): "Please pass `regex` to `structured_outputs` instead." ), ) - guided_choice: Optional[list[str]] = Field( + guided_choice: list[str] | None = Field( default=None, description=( "`guided_choice` is deprecated. " @@ -1179,7 +1172,7 @@ class CompletionRequest(OpenAIBaseModel): "Please pass `choice` to `structured_outputs` instead." ), ) - guided_grammar: Optional[str] = Field( + guided_grammar: str | None = Field( default=None, description=( "`guided_grammar` is deprecated. " @@ -1187,7 +1180,11 @@ class CompletionRequest(OpenAIBaseModel): "Please pass `grammar` to `structured_outputs` instead." ), ) - guided_decoding_backend: Optional[str] = Field( + structural_tag: str | None = Field( + default=None, + description=("If specified, the output will follow the structural tag schema."), + ) + guided_decoding_backend: str | None = Field( default=None, description=( "`guided_decoding_backend` is deprecated. " @@ -1195,7 +1192,7 @@ class CompletionRequest(OpenAIBaseModel): "Please remove it from your request." ), ) - guided_whitespace_pattern: Optional[str] = Field( + guided_whitespace_pattern: str | None = Field( default=None, description=( "`guided_whitespace_pattern` is deprecated. " @@ -1219,7 +1216,7 @@ class CompletionRequest(OpenAIBaseModel): "through out the inference process and return in response." ), ) - logits_processors: Optional[LogitsProcessors] = Field( + logits_processors: LogitsProcessors | None = Field( default=None, description=( "A list of either qualified names of logits processors, or " @@ -1233,7 +1230,7 @@ class CompletionRequest(OpenAIBaseModel): ), ) - return_tokens_as_token_ids: Optional[bool] = Field( + return_tokens_as_token_ids: bool | None = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " @@ -1241,7 +1238,7 @@ class CompletionRequest(OpenAIBaseModel): "that are not JSON-encodable can be identified." 
), ) - return_token_ids: Optional[bool] = Field( + return_token_ids: bool | None = Field( default=None, description=( "If specified, the result will include token IDs alongside the " @@ -1252,7 +1249,7 @@ class CompletionRequest(OpenAIBaseModel): ), ) - cache_salt: Optional[str] = Field( + cache_salt: str | None = Field( default=None, description=( "If specified, the prefix cache will be salted with the provided " @@ -1264,12 +1261,12 @@ class CompletionRequest(OpenAIBaseModel): ), ) - kv_transfer_params: Optional[dict[str, Any]] = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters used for disaggregated serving.", ) - vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( + vllm_xargs: dict[str, str | int | float] | None = Field( default=None, description=( "Additional request parameters with string or " @@ -1291,7 +1288,7 @@ class CompletionRequest(OpenAIBaseModel): def to_beam_search_params( self, max_tokens: int, - default_sampling_params: Optional[dict] = None, + default_sampling_params: dict | None = None, ) -> BeamSearchParams: if default_sampling_params is None: default_sampling_params = {} @@ -1312,8 +1309,8 @@ def to_beam_search_params( def to_sampling_params( self, max_tokens: int, - logits_processor_pattern: Optional[str], - default_sampling_params: Optional[dict] = None, + logits_processor_pattern: str | None, + default_sampling_params: dict | None = None, ) -> SamplingParams: if default_sampling_params is None: default_sampling_params = {} @@ -1347,10 +1344,27 @@ def to_sampling_params( echo_without_generation = self.echo and self.max_tokens == 0 + guided_json_object = None + if self.response_format is not None: + if self.response_format.type == "json_object": + guided_json_object = True + elif self.response_format.type == "json_schema": + json_schema = self.response_format.json_schema + assert json_schema is not None + self.guided_json = json_schema.json_schema + elif self.response_format.type == "structural_tag": + structural_tag = self.response_format + assert structural_tag is not None and isinstance( + structural_tag, StructuralTagResponseFormat + ) + s_tag_obj = structural_tag.model_dump(by_alias=True) + self.structural_tag = json.dumps(s_tag_obj) + # Forward deprecated guided_* parameters to structured_outputs if self.structured_outputs is None: kwargs = dict[str, Any]( json=self.guided_json, + json_object=guided_json_object, regex=self.guided_regex, choice=self.guided_choice, grammar=self.guided_grammar, @@ -1360,13 +1374,6 @@ def to_sampling_params( if len(kwargs) > 0: self.structured_outputs = StructuredOutputsParams(**kwargs) - if ( - self.structured_outputs is not None - and self.response_format is not None - and self.response_format.type == "json_object" - ): - self.structured_outputs.json_object = True - extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} if self.kv_transfer_params: # Pass in kv_transfer_params via extra_args @@ -1488,12 +1495,12 @@ def check_cache_salt_support(cls, data): class EmbeddingCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/embeddings - model: Optional[str] = None - input: Union[list[int], list[list[int]], str, list[str]] - encoding_format: Literal["float", "base64"] = "float" - dimensions: Optional[int] = None - user: Optional[str] = None - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + model: str | None = None + input: list[int] | 
list[list[int]] | str | list[str] + encoding_format: EncodingFormat = "float" + dimensions: int | None = None + user: str | None = None + truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None # --8<-- [start:embedding-extra-params] add_special_tokens: bool = Field( @@ -1519,8 +1526,26 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): "through out the inference process and return in response." ), ) - normalize: Optional[bool] = None - + normalize: bool | None = Field( + default=None, + description="Whether to normalize the embedding outputs. Default is True.", + ) + embed_dtype: EmbedDType = Field( + default="float32", + description=( + "What dtype to use for encoding. Defaults to float32 for base64 " + "encoding to match the OpenAI python client behavior. " + "This parameter will affect base64 and binary_response." + ), + ) + endianness: Endianness = Field( + default="native", + description=( + "What endianness to use for encoding. Defaults to native for " + "base64 encoding to match the OpenAI python client behavior. " + "This parameter will affect base64 and binary_response." + ), + ) # --8<-- [end:embedding-extra-params] def to_pooling_params(self): @@ -1532,13 +1557,13 @@ def to_pooling_params(self): class EmbeddingChatRequest(OpenAIBaseModel): - model: Optional[str] = None + model: str | None = None messages: list[ChatCompletionMessageParam] - encoding_format: Literal["float", "base64"] = "float" - dimensions: Optional[int] = None - user: Optional[str] = None - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + encoding_format: EncodingFormat = "float" + dimensions: int | None = None + user: str | None = None + truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None # --8<-- [start:chat-embedding-extra-params] add_generation_prompt: bool = Field( @@ -1560,7 +1585,7 @@ class EmbeddingChatRequest(OpenAIBaseModel): "default)." ), ) - chat_template: Optional[str] = Field( + chat_template: str | None = Field( default=None, description=( "A Jinja template to use for this conversion. " @@ -1569,14 +1594,14 @@ class EmbeddingChatRequest(OpenAIBaseModel): "does not define one." ), ) - chat_template_kwargs: Optional[dict[str, Any]] = Field( + chat_template_kwargs: dict[str, Any] | None = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " "Will be accessible by the chat template." ), ) - mm_processor_kwargs: Optional[dict[str, Any]] = Field( + mm_processor_kwargs: dict[str, Any] | None = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) @@ -1596,7 +1621,26 @@ class EmbeddingChatRequest(OpenAIBaseModel): "through out the inference process and return in response." ), ) - normalize: Optional[bool] = None + normalize: bool | None = Field( + default=None, + description="Whether to normalize the embedding outputs. Default is True.", + ) + embed_dtype: EmbedDType = Field( + default="float32", + description=( + "What dtype to use for encoding. Defaults to float32 for base64 " + "encoding to match the OpenAI python client behavior. " + "This parameter will affect base64 and binary_response." + ), + ) + endianness: Endianness = Field( + default="native", + description=( + "What endianness to use for encoding. Defaults to native for " + "base64 encoding to match the OpenAI python client behavior. " + "This parameter will affect base64 and binary_response." 
+ ), + ) # --8<-- [end:chat-embedding-extra-params] @model_validator(mode="before") @@ -1617,7 +1661,7 @@ def to_pooling_params(self): ) -EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] +EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest PoolingCompletionRequest = EmbeddingCompletionRequest PoolingChatRequest = EmbeddingChatRequest @@ -1626,7 +1670,7 @@ def to_pooling_params(self): class IOProcessorRequest(OpenAIBaseModel, Generic[T]): - model: Optional[str] = None + model: str | None = None priority: int = Field(default=0) """ @@ -1635,18 +1679,31 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): if the served model does not use priority scheduling. """ data: T - """ - When using plugins IOProcessor plugins, the actual input is processed - by the plugin itself. Hence, we use a generic type for the request data - """ - softmax: bool = True + + encoding_format: EncodingFormat = "float" + embed_dtype: EmbedDType = Field( + default="float32", + description=( + "What dtype to use for encoding. Defaults to float32 for base64 " + "encoding to match the OpenAI python client behavior. " + "This parameter will affect base64 and binary_response." + ), + ) + endianness: Endianness = Field( + default="native", + description=( + "What endianness to use for encoding. Defaults to native for " + "base64 encoding to match the OpenAI python client behavior. " + "This parameter will affect base64 and binary_response." + ), + ) def to_pooling_params(self): - return PoolingParams(task="encode", softmax=self.softmax) + return PoolingParams() class IOProcessorResponse(OpenAIBaseModel, Generic[T]): - request_id: Optional[str] = None + request_id: str | None = None """ The request_id associated with this response """ @@ -1659,18 +1716,20 @@ class IOProcessorResponse(OpenAIBaseModel, Generic[T]): """ -PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, IOProcessorRequest] +PoolingRequest: TypeAlias = ( + PoolingCompletionRequest | PoolingChatRequest | IOProcessorRequest +) class ScoreRequest(OpenAIBaseModel): - model: Optional[str] = None - text_1: Union[list[str], str, ScoreMultiModalParam] - text_2: Union[list[str], str, ScoreMultiModalParam] - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + model: str | None = None + text_1: list[str] | str | ScoreMultiModalParam + text_2: list[str] | str | ScoreMultiModalParam + truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None # --8<-- [start:score-extra-params] - mm_processor_kwargs: Optional[dict[str, Any]] = Field( + mm_processor_kwargs: dict[str, Any] | 
None = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) @@ -1718,7 +1777,7 @@ class RerankRequest(OpenAIBaseModel): ), ) - activation: Optional[bool] = None + activation: bool | None = None # --8<-- [end:rerank-extra-params] @@ -1730,8 +1789,8 @@ def to_pooling_params(self): class RerankDocument(BaseModel): - text: Optional[str] = None - multi_modal: Optional[ScoreContentPartParam] = None + text: str | None = None + multi_modal: ScoreContentPartParam | None = None class RerankResult(BaseModel): @@ -1753,17 +1812,17 @@ class RerankResponse(OpenAIBaseModel): class CompletionLogProbs(OpenAIBaseModel): text_offset: list[int] = Field(default_factory=list) - token_logprobs: list[Optional[float]] = Field(default_factory=list) + token_logprobs: list[float | None] = Field(default_factory=list) tokens: list[str] = Field(default_factory=list) - top_logprobs: list[Optional[dict[str, float]]] = Field(default_factory=list) + top_logprobs: list[dict[str, float] | None] = Field(default_factory=list) class CompletionResponseChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[CompletionLogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = Field( + logprobs: CompletionLogProbs | None = None + finish_reason: str | None = None + stop_reason: int | str | None = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -1771,9 +1830,9 @@ class CompletionResponseChoice(OpenAIBaseModel): "including encountering the EOS token" ), ) - token_ids: Optional[list[int]] = None # For response - prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None - prompt_token_ids: Optional[list[int]] = None # For prompt + token_ids: list[int] | None = None # For response + prompt_logprobs: list[dict[int, Logprob] | None] | None = None + prompt_token_ids: list[int] | None = None # For prompt class CompletionResponse(OpenAIBaseModel): @@ -1782,14 +1841,12 @@ class CompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[CompletionResponseChoice] - service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = ( - None - ) - system_fingerprint: Optional[str] = None + service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None + system_fingerprint: str | None = None usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec - kv_transfer_params: Optional[dict[str, Any]] = Field( + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters." 
) @@ -1797,9 +1854,9 @@ class CompletionResponse(OpenAIBaseModel): class CompletionResponseStreamChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[CompletionLogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = Field( + logprobs: CompletionLogProbs | None = None + finish_reason: str | None = None + stop_reason: int | str | None = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -1809,8 +1866,8 @@ class CompletionResponseStreamChoice(OpenAIBaseModel): ) # not part of the OpenAI spec but for tracing the tokens # prompt tokens is put into choice to align with CompletionResponseChoice - prompt_token_ids: Optional[list[int]] = None - token_ids: Optional[list[int]] = None + prompt_token_ids: list[int] | None = None + token_ids: list[int] | None = None class CompletionStreamResponse(OpenAIBaseModel): @@ -1819,13 +1876,13 @@ class CompletionStreamResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[CompletionResponseStreamChoice] - usage: Optional[UsageInfo] = Field(default=None) + usage: UsageInfo | None = Field(default=None) class EmbeddingResponseData(OpenAIBaseModel): index: int object: str = "embedding" - embedding: Union[list[float], str] + embedding: list[float] | str class EmbeddingResponse(OpenAIBaseModel): @@ -1837,10 +1894,16 @@ class EmbeddingResponse(OpenAIBaseModel): usage: UsageInfo +class EmbeddingBytesResponse(OpenAIBaseModel): + body: list[bytes] + metadata: str + media_type: str = "application/octet-stream" + + class PoolingResponseData(OpenAIBaseModel): index: int object: str = "pooling" - data: Union[list[list[float]], list[float], str] + data: list[list[float]] | list[float] | str class PoolingResponse(OpenAIBaseModel): @@ -1852,6 +1915,12 @@ class PoolingResponse(OpenAIBaseModel): usage: UsageInfo +class PoolingBytesResponse(OpenAIBaseModel): + body: list[bytes] + metadata: str + media_type: str = "application/octet-stream" + + class ScoreResponseData(OpenAIBaseModel): index: int object: str = "score" @@ -1868,10 +1937,10 @@ class ScoreResponse(OpenAIBaseModel): class ClassificationRequest(OpenAIBaseModel): - model: Optional[str] = None - input: Union[list[str], str] - truncate_prompt_tokens: Optional[int] = None - user: Optional[str] = None + model: str | None = None + input: list[str] | str + truncate_prompt_tokens: int | None = None + user: str | None = None # --8<-- [start:classification-extra-params] priority: int = Field( @@ -1883,7 +1952,7 @@ class ClassificationRequest(OpenAIBaseModel): ), ) - activation: Optional[bool] = None + activation: bool | None = None # --8<-- [end:classification-extra-params] @@ -1896,7 +1965,7 @@ def to_pooling_params(self): class ClassificationData(OpenAIBaseModel): index: int - label: Optional[str] + label: str | None probs: list[float] num_classes: int @@ -1922,16 +1991,16 @@ class ToolCall(OpenAIBaseModel): class DeltaFunctionCall(BaseModel): - name: Optional[str] = None - arguments: Optional[str] = None + name: str | None = None + arguments: str | None = None # a tool call delta where everything is optional class DeltaToolCall(OpenAIBaseModel): - id: Optional[str] = None - type: Optional[Literal["function"]] = None + id: str | None = None + type: Literal["function"] | None = None index: int - function: Optional[DeltaFunctionCall] = None + function: DeltaFunctionCall | None = None class ExtractedToolCallInformation(BaseModel): @@ -1943,50 +2012,50 @@ class 
ExtractedToolCallInformation(BaseModel): # content - per OpenAI spec, content AND tool calls can be returned rarely # But some models will do this intentionally - content: Optional[str] = None + content: str | None = None class ChatMessage(OpenAIBaseModel): role: str - content: Optional[str] = None - refusal: Optional[str] = None - annotations: Optional[OpenAIAnnotation] = None - audio: Optional[OpenAIChatCompletionAudio] = None - function_call: Optional[FunctionCall] = None + content: str | None = None + refusal: str | None = None + annotations: OpenAIAnnotation | None = None + audio: OpenAIChatCompletionAudio | None = None + function_call: FunctionCall | None = None tool_calls: list[ToolCall] = Field(default_factory=list) # vLLM-specific fields that are not in OpenAI spec - reasoning_content: Optional[str] = None + reasoning_content: str | None = None class ChatCompletionLogProb(OpenAIBaseModel): token: str logprob: float = -9999.0 - bytes: Optional[list[int]] = None + bytes: list[int] | None = None class ChatCompletionLogProbsContent(ChatCompletionLogProb): # Workaround: redefine fields name cache so that it's not # shared with the super class. - field_names: ClassVar[Optional[set[str]]] = None + field_names: ClassVar[set[str] | None] = None top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list) class ChatCompletionLogProbs(OpenAIBaseModel): - content: Optional[list[ChatCompletionLogProbsContent]] = None + content: list[ChatCompletionLogProbsContent] | None = None class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage - logprobs: Optional[ChatCompletionLogProbs] = None + logprobs: ChatCompletionLogProbs | None = None # per OpenAI spec this is the default - finish_reason: Optional[str] = "stop" + finish_reason: str | None = "stop" # not part of the OpenAI spec but included in vLLM for legacy reasons - stop_reason: Optional[Union[int, str]] = None + stop_reason: int | str | None = None # not part of the OpenAI spec but is useful for tracing the tokens # in agent scenarios - token_ids: Optional[list[int]] = None + token_ids: list[int] | None = None class ChatCompletionResponse(OpenAIBaseModel): @@ -1995,35 +2064,33 @@ class ChatCompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[ChatCompletionResponseChoice] - service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = ( - None - ) - system_fingerprint: Optional[str] = None + service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None + system_fingerprint: str | None = None usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec - prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None - prompt_token_ids: Optional[list[int]] = None - kv_transfer_params: Optional[dict[str, Any]] = Field( + prompt_logprobs: list[dict[int, Logprob] | None] | None = None + prompt_token_ids: list[int] | None = None + kv_transfer_params: dict[str, Any] | None = Field( default=None, description="KVTransfer parameters." 
) class DeltaMessage(OpenAIBaseModel): - role: Optional[str] = None - content: Optional[str] = None - reasoning_content: Optional[str] = None + role: str | None = None + content: str | None = None + reasoning_content: str | None = None tool_calls: list[DeltaToolCall] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): index: int delta: DeltaMessage - logprobs: Optional[ChatCompletionLogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = None + logprobs: ChatCompletionLogProbs | None = None + finish_reason: str | None = None + stop_reason: int | str | None = None # not part of the OpenAI spec but for tracing the tokens - token_ids: Optional[list[int]] = None + token_ids: list[int] | None = None class ChatCompletionStreamResponse(OpenAIBaseModel): @@ -2032,15 +2099,15 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[ChatCompletionResponseStreamChoice] - usage: Optional[UsageInfo] = Field(default=None) + usage: UsageInfo | None = Field(default=None) # not part of the OpenAI spec but for tracing the tokens - prompt_token_ids: Optional[list[int]] = None + prompt_token_ids: list[int] | None = None class TranscriptionResponseStreamChoice(OpenAIBaseModel): delta: DeltaMessage - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = None + finish_reason: str | None = None + stop_reason: int | str | None = None class TranscriptionStreamResponse(OpenAIBaseModel): @@ -2049,16 +2116,20 @@ class TranscriptionStreamResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[TranscriptionResponseStreamChoice] - usage: Optional[UsageInfo] = Field(default=None) + usage: UsageInfo | None = Field(default=None) class InputTokensDetails(OpenAIBaseModel): cached_tokens: int + input_tokens_per_turn: list[int] = Field(default_factory=list) + cached_tokens_per_turn: list[int] = Field(default_factory=list) class OutputTokensDetails(OpenAIBaseModel): reasoning_tokens: int = 0 tool_output_tokens: int = 0 + output_tokens_per_turn: list[int] = Field(default_factory=list) + tool_output_tokens_per_turn: list[int] = Field(default_factory=list) class ResponseUsage(OpenAIBaseModel): @@ -2069,13 +2140,33 @@ class ResponseUsage(OpenAIBaseModel): total_tokens: int +def serialize_message(msg): + """ + Serializes a single message + """ + if isinstance(msg, dict): + return msg + elif hasattr(msg, "to_dict"): + return msg.to_dict() + else: + # fallback to pydantic dump + return msg.model_dump_json() + + +def serialize_messages(msgs): + """ + Serializes multiple messages + """ + return [serialize_message(msg) for msg in msgs] if msgs else None + + class ResponsesResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") created_at: int = Field(default_factory=lambda: int(time.time())) # error: Optional[ResponseError] = None - incomplete_details: Optional[IncompleteDetails] = None - instructions: Optional[str] = None - metadata: Optional[Metadata] = None + incomplete_details: IncompleteDetails | None = None + instructions: str | None = None + metadata: Metadata | None = None model: str object: Literal["response"] = "response" output: list[ResponseOutputItem] @@ -2086,24 +2177,24 @@ class ResponsesResponse(OpenAIBaseModel): top_p: float background: bool max_output_tokens: int - max_tool_calls: Optional[int] = None - previous_response_id: Optional[str] = 
None - prompt: Optional[ResponsePrompt] = None - reasoning: Optional[Reasoning] = None + max_tool_calls: int | None = None + previous_response_id: str | None = None + prompt: ResponsePrompt | None = None + reasoning: Reasoning | None = None service_tier: Literal["auto", "default", "flex", "scale", "priority"] status: ResponseStatus - text: Optional[ResponseTextConfig] = None - top_logprobs: Optional[int] = None + text: ResponseTextConfig | None = None + top_logprobs: int | None = None truncation: Literal["auto", "disabled"] - usage: Optional[ResponseUsage] = None - user: Optional[str] = None + usage: ResponseUsage | None = None + user: str | None = None # --8<-- [start:responses-extra-params] # These are populated when enable_response_messages is set to True # NOTE: custom serialization is needed # see serialize_input_messages and serialize_output_messages - input_messages: Optional[list[ChatCompletionMessageParam]] = None - output_messages: Optional[list[ChatCompletionMessageParam]] = None + input_messages: list[ChatCompletionMessageParam] | None = None + output_messages: list[ChatCompletionMessageParam] | None = None # --8<-- [end:responses-extra-params] # NOTE: openAI harmony doesn't serialize TextContent properly, @@ -2111,35 +2202,13 @@ class ResponsesResponse(OpenAIBaseModel): # https://github.com/openai/harmony/issues/78 @field_serializer("output_messages", when_used="json") def serialize_output_messages(self, msgs, _info): - if msgs: - serialized = [] - for m in msgs: - if isinstance(m, dict): - serialized.append(m) - elif hasattr(m, "__dict__"): - serialized.append(m.to_dict()) - else: - # fallback to pyandic dump - serialized.append(m.model_dump_json()) - return serialized - return None + return serialize_messages(msgs) # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it # https://github.com/openai/harmony/issues/78 @field_serializer("input_messages", when_used="json") def serialize_input_messages(self, msgs, _info): - if msgs: - serialized = [] - for m in msgs: - if isinstance(m, dict): - serialized.append(m) - elif hasattr(m, "__dict__"): - serialized.append(m.to_dict()) - else: - # fallback to pyandic dump - serialized.append(m.model_dump_json()) - return serialized - return None + return serialize_messages(msgs) @classmethod def from_request( @@ -2150,11 +2219,11 @@ def from_request( created_time: int, output: list[ResponseOutputItem], status: ResponseStatus, - usage: Optional[ResponseUsage] = None, - input_messages: Optional[list[ChatCompletionMessageParam]] = None, - output_messages: Optional[list[ChatCompletionMessageParam]] = None, + usage: ResponseUsage | None = None, + input_messages: list[ChatCompletionMessageParam] | None = None, + output_messages: list[ChatCompletionMessageParam] | None = None, ) -> "ResponsesResponse": - incomplete_details: Optional[IncompleteDetails] = None + incomplete_details: IncompleteDetails | None = None if status == "incomplete": incomplete_details = IncompleteDetails(reason="max_output_tokens") # TODO: implement the other reason for incomplete_details, @@ -2249,31 +2318,31 @@ class ResponseInProgressEvent(OpenAIResponseInProgressEvent): response: ResponsesResponse # type: ignore[override] -StreamingResponsesResponse: TypeAlias = Union[ - "ResponseCreatedEvent", - "ResponseInProgressEvent", - "ResponseCompletedEvent", - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseReasoningTextDeltaEvent, - ResponseReasoningTextDoneEvent, - 
ResponseReasoningPartAddedEvent, - ResponseReasoningPartDoneEvent, - ResponseCodeInterpreterCallInProgressEvent, - ResponseCodeInterpreterCallCodeDeltaEvent, - ResponseWebSearchCallInProgressEvent, - ResponseWebSearchCallSearchingEvent, - ResponseWebSearchCallCompletedEvent, - ResponseCodeInterpreterCallCodeDoneEvent, - ResponseCodeInterpreterCallInterpretingEvent, - ResponseCodeInterpreterCallCompletedEvent, -] +StreamingResponsesResponse: TypeAlias = ( + ResponseCreatedEvent + | ResponseInProgressEvent + | ResponseCompletedEvent + | ResponseOutputItemAddedEvent + | ResponseOutputItemDoneEvent + | ResponseContentPartAddedEvent + | ResponseContentPartDoneEvent + | ResponseReasoningTextDeltaEvent + | ResponseReasoningTextDoneEvent + | ResponseReasoningPartAddedEvent + | ResponseReasoningPartDoneEvent + | ResponseCodeInterpreterCallInProgressEvent + | ResponseCodeInterpreterCallCodeDeltaEvent + | ResponseWebSearchCallInProgressEvent + | ResponseWebSearchCallSearchingEvent + | ResponseWebSearchCallCompletedEvent + | ResponseCodeInterpreterCallCodeDoneEvent + | ResponseCodeInterpreterCallInterpretingEvent + | ResponseCodeInterpreterCallCompletedEvent +) -BatchRequestInputBody = Union[ - ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest -] +BatchRequestInputBody: TypeAlias = ( + ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest +) class BatchRequestInput(OpenAIBaseModel): @@ -2322,9 +2391,13 @@ class BatchResponseData(OpenAIBaseModel): request_id: str # The body of the response. - body: Optional[ - Union[ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse] - ] = None + body: ( + ChatCompletionResponse + | EmbeddingResponse + | ScoreResponse + | RerankResponse + | None + ) = None class BatchRequestOutput(OpenAIBaseModel): @@ -2338,15 +2411,15 @@ class BatchRequestOutput(OpenAIBaseModel): # inputs. custom_id: str - response: Optional[BatchResponseData] + response: BatchResponseData | None # For requests that failed with a non-HTTP error, this will contain more # information on the cause of the failure. - error: Optional[Any] + error: Any | None class TokenizeCompletionRequest(OpenAIBaseModel): - model: Optional[str] = None + model: str | None = None prompt: str add_special_tokens: bool = Field( @@ -2356,7 +2429,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel): "the prompt." ), ) - return_token_strs: Optional[bool] = Field( + return_token_strs: bool | None = Field( default=False, description=( "If true, also return the token strings corresponding to the token ids." @@ -2365,7 +2438,7 @@ class TokenizeCompletionRequest(OpenAIBaseModel): class TokenizeChatRequest(OpenAIBaseModel): - model: Optional[str] = None + model: str | None = None messages: list[ChatCompletionMessageParam] add_generation_prompt: bool = Field( @@ -2376,7 +2449,7 @@ class TokenizeChatRequest(OpenAIBaseModel): "model." ), ) - return_token_strs: Optional[bool] = Field( + return_token_strs: bool | None = Field( default=False, description=( "If true, also return the token strings corresponding to the token ids." @@ -2402,7 +2475,7 @@ class TokenizeChatRequest(OpenAIBaseModel): "default)." ), ) - chat_template: Optional[str] = Field( + chat_template: str | None = Field( default=None, description=( "A Jinja template to use for this conversion. " @@ -2411,18 +2484,18 @@ class TokenizeChatRequest(OpenAIBaseModel): "does not define one." 
), ) - chat_template_kwargs: Optional[dict[str, Any]] = Field( + chat_template_kwargs: dict[str, Any] | None = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " "Will be accessible by the chat template." ), ) - mm_processor_kwargs: Optional[dict[str, Any]] = Field( + mm_processor_kwargs: dict[str, Any] | None = Field( default=None, description=("Additional kwargs to pass to the HF processor."), ) - tools: Optional[list[ChatCompletionToolsParam]] = Field( + tools: list[ChatCompletionToolsParam] | None = Field( default=None, description=("A list of tools the model may call."), ) @@ -2438,18 +2511,18 @@ def check_generation_prompt(cls, data): return data -TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] +TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest class TokenizeResponse(OpenAIBaseModel): count: int max_model_len: int tokens: list[int] - token_strs: Optional[list[str]] = None + token_strs: list[str] | None = None class DetokenizeRequest(OpenAIBaseModel): - model: Optional[str] = None + model: str | None = None tokens: list[int] @@ -2474,7 +2547,7 @@ class LoadLoRAAdapterRequest(BaseModel): class UnloadLoRAAdapterRequest(BaseModel): lora_name: str - lora_int_id: Optional[int] = Field(default=None) + lora_int_id: int | None = Field(default=None) ## Protocols for Audio @@ -2491,11 +2564,11 @@ class TranscriptionRequest(OpenAIBaseModel): formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. """ - model: Optional[str] = None + model: str | None = None """ID of the model to use. """ - language: Optional[str] = None + language: str | None = None """The language of the input audio. Supplying the input language in @@ -2530,16 +2603,16 @@ class TranscriptionRequest(OpenAIBaseModel): timestamps incurs additional latency. """ - stream: Optional[bool] = False + stream: bool | None = False """When set, it will enable output to be streamed in a similar fashion as the Chat Completion endpoint. """ # --8<-- [start:transcription-extra-params] # Flattened stream option to simplify form data. - stream_include_usage: Optional[bool] = False - stream_continuous_usage_stats: Optional[bool] = False + stream_include_usage: bool | None = False + stream_continuous_usage_stats: bool | None = False - vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( + vllm_xargs: dict[str, str | int | float] | None = Field( default=None, description=( "Additional request parameters with string or " @@ -2548,7 +2621,7 @@ class TranscriptionRequest(OpenAIBaseModel): ) # --8<-- [end:transcription-extra-params] - to_language: Optional[str] = None + to_language: str | None = None """The language of the output audio we transcribe to. Please note that this is not currently used by supported models at this @@ -2565,29 +2638,29 @@ class TranscriptionRequest(OpenAIBaseModel): to automatically increase the temperature until certain thresholds are hit. """ - top_p: Optional[float] = None + top_p: float | None = None """Enables nucleus (top-p) sampling, where tokens are selected from the smallest possible set whose cumulative probability exceeds `p`. """ - top_k: Optional[int] = None + top_k: int | None = None """Limits sampling to the `k` most probable tokens at each step.""" - min_p: Optional[float] = None + min_p: float | None = None """Filters out tokens with a probability lower than `min_p`, ensuring a minimum likelihood threshold during sampling. 
""" - seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) """The seed to use for sampling.""" - frequency_penalty: Optional[float] = 0.0 + frequency_penalty: float | None = 0.0 """The frequency penalty to use for sampling.""" - repetition_penalty: Optional[float] = None + repetition_penalty: float | None = None """The repetition penalty to use for sampling.""" - presence_penalty: Optional[float] = 0.0 + presence_penalty: float | None = 0.0 """The presence penalty to use for sampling.""" # --8<-- [end:transcription-sampling-params] @@ -2601,7 +2674,7 @@ class TranscriptionRequest(OpenAIBaseModel): } def to_sampling_params( - self, default_max_tokens: int, default_sampling_params: Optional[dict] = None + self, default_max_tokens: int, default_sampling_params: dict | None = None ) -> SamplingParams: max_tokens = default_max_tokens @@ -2740,17 +2813,17 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): text: str """The transcribed text.""" - segments: Optional[list[TranscriptionSegment]] = None + segments: list[TranscriptionSegment] | None = None """Segments of the transcribed text and their corresponding details.""" - words: Optional[list[TranscriptionWord]] = None + words: list[TranscriptionWord] | None = None """Extracted words and their corresponding timestamps.""" class TranslationResponseStreamChoice(OpenAIBaseModel): delta: DeltaMessage - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = None + finish_reason: str | None = None + stop_reason: int | str | None = None class TranslationStreamResponse(OpenAIBaseModel): @@ -2759,7 +2832,7 @@ class TranslationStreamResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[TranslationResponseStreamChoice] - usage: Optional[UsageInfo] = Field(default=None) + usage: UsageInfo | None = Field(default=None) class TranslationRequest(OpenAIBaseModel): @@ -2772,7 +2845,7 @@ class TranslationRequest(OpenAIBaseModel): formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. """ - model: Optional[str] = None + model: str | None = None """ID of the model to use. """ @@ -2792,7 +2865,7 @@ class TranslationRequest(OpenAIBaseModel): # TODO support additional sampling parameters # --8<-- [start:translation-sampling-params] - seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) """The seed to use for sampling.""" temperature: float = Field(default=0.0) @@ -2806,7 +2879,7 @@ class TranslationRequest(OpenAIBaseModel): # --8<-- [end:translation-sampling-params] # --8<-- [start:translation-extra-params] - language: Optional[str] = None + language: str | None = None """The language of the input audio we translate from. Supplying the input language in @@ -2814,7 +2887,7 @@ class TranslationRequest(OpenAIBaseModel): will improve accuracy. """ - to_language: Optional[str] = None + to_language: str | None = None """The language of the input audio we translate to. Please note that this is not supported by all models, refer to the specific @@ -2822,14 +2895,14 @@ class TranslationRequest(OpenAIBaseModel): For instance, Whisper only supports `to_language=en`. """ - stream: Optional[bool] = False + stream: bool | None = False """Custom field not present in the original OpenAI definition. When set, it will enable output to be streamed in a similar fashion as the Chat Completion endpoint. 
""" # Flattened stream option to simplify form data. - stream_include_usage: Optional[bool] = False - stream_continuous_usage_stats: Optional[bool] = False + stream_include_usage: bool | None = False + stream_continuous_usage_stats: bool | None = False # --8<-- [end:translation-extra-params] # Default sampling parameters for translation requests. @@ -2838,7 +2911,7 @@ class TranslationRequest(OpenAIBaseModel): } def to_sampling_params( - self, default_max_tokens: int, default_sampling_params: Optional[dict] = None + self, default_max_tokens: int, default_sampling_params: dict | None = None ) -> SamplingParams: max_tokens = default_max_tokens @@ -2939,8 +3012,8 @@ class TranslationResponseVerbose(OpenAIBaseModel): text: str """The translated text.""" - segments: Optional[list[TranslationSegment]] = None + segments: list[TranslationSegment] | None = None """Segments of the translated text and their corresponding details.""" - words: Optional[list[TranslationWord]] = None + words: list[TranslationWord] | None = None """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 030ce3ce0844..da036e30ba7e 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -4,17 +4,15 @@ import asyncio import tempfile from argparse import Namespace -from collections.abc import Awaitable +from collections.abc import Awaitable, Callable from http import HTTPStatus from io import StringIO -from typing import Callable, Optional import aiohttp import torch from prometheus_client import start_http_server from tqdm import tqdm -from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger @@ -33,6 +31,7 @@ from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.serving_score import ServingScores from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION @@ -106,6 +105,13 @@ def make_arg_parser(parser: FlexibleArgumentParser): default=False, help="If set to True, enable prompt_tokens_details in usage.", ) + parser.add_argument( + "--enable-force-include-usage", + action="store_true", + default=False, + help="If set to True, include usage on every request " + "(even when stream_options is not specified)", + ) return parser @@ -125,7 +131,7 @@ def parse_args(): class BatchProgressTracker: def __init__(self): self._total = 0 - self._pbar: Optional[tqdm] = None + self._pbar: tqdm | None = None def submitted(self): self._total += 1 @@ -326,9 +332,19 @@ async def run_request( return batch_output +def validate_run_batch_args(args): + valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() + if ( + reasoning_parser := args.structured_outputs_config.reasoning_parser + ) and reasoning_parser not in valid_reasoning_parses: + raise KeyError( + f"invalid reasoning parser: {reasoning_parser} " + f"(chose from {{ {','.join(valid_reasoning_parses)} }})" + ) + + async def run_batch( engine_client: EngineClient, - vllm_config: VllmConfig, args: Namespace, ) -> None: if args.served_model_name is not None: @@ -345,36 +361,36 @@ async def run_batch( BaseModelPath(name=name, model_path=args.model) for name in served_model_names ] - model_config = 
vllm_config.model_config - + model_config = engine_client.model_config supported_tasks = await engine_client.get_supported_tasks() - logger.info("Supported_tasks: %s", supported_tasks) + logger.info("Supported tasks: %s", supported_tasks) # Create the openai serving objects. openai_serving_models = OpenAIServingModels( engine_client=engine_client, - model_config=model_config, base_model_paths=base_model_paths, lora_modules=None, ) + openai_serving_chat = ( OpenAIServingChat( engine_client, - model_config, openai_serving_models, args.response_role, request_logger=request_logger, chat_template=None, chat_template_content_format="auto", + reasoning_parser=args.structured_outputs_config.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, ) if "generate" in supported_tasks else None ) + openai_serving_embedding = ( OpenAIServingEmbedding( engine_client, - model_config, openai_serving_models, request_logger=request_logger, chat_template=None, @@ -392,7 +408,6 @@ async def run_batch( openai_serving_scores = ( ServingScores( engine_client, - model_config, openai_serving_models, request_logger=request_logger, ) @@ -504,14 +519,14 @@ async def main(args: Namespace): from vllm.entrypoints.openai.api_server import build_async_engine_client from vllm.usage.usage_lib import UsageContext + validate_run_batch_args(args) + async with build_async_engine_client( args, usage_context=UsageContext.OPENAI_BATCH_RUNNER, disable_frontend_multiprocessing=False, ) as engine_client: - vllm_config = await engine_client.get_vllm_config() - - await run_batch(engine_client, vllm_config, args) + await run_batch(engine_client, args) if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 12dd474936db..3bf887c659dc 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -6,7 +6,7 @@ import time from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import Sequence as GenericSequence -from typing import Callable, Final, Optional, Union +from typing import Final import jinja2 import partial_json_parser @@ -15,7 +15,6 @@ from openai_harmony import Message as OpenAIMessage from pydantic import TypeAdapter -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ( ChatTemplateContentFormatOption, @@ -57,14 +56,13 @@ ) from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall -from vllm.entrypoints.utils import get_max_tokens +from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import CompletionOutput, RequestOutput -from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizers import ( @@ -72,7 +70,7 @@ truncate_tool_call_ids, validate_request_params, ) -from 
vllm.utils import as_list +from vllm.utils.collection_utils import as_list logger = init_logger(__name__) @@ -81,19 +79,18 @@ class OpenAIServingChat(OpenAIServing): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, response_role: str, *, - request_logger: Optional[RequestLogger], - chat_template: Optional[str], + request_logger: RequestLogger | None, + chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, trust_request_chat_template: bool = False, return_tokens_as_token_ids: bool = False, reasoning_parser: str = "", enable_auto_tools: bool = False, exclude_tools_when_tool_choice_none: bool = False, - tool_parser: Optional[str] = None, + tool_parser: str | None = None, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, @@ -101,11 +98,9 @@ def __init__( ) -> None: super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage, log_error_stack=log_error_stack, ) @@ -115,42 +110,15 @@ def __init__( self.trust_request_chat_template = trust_request_chat_template self.enable_log_outputs = enable_log_outputs + # set up reasoning parser + self.reasoning_parser = self._get_reasoning_parser( + reasoning_parser_name=reasoning_parser + ) # set up tool use self.enable_auto_tools: bool = enable_auto_tools - if self.enable_auto_tools: - logger.info( - '"auto" tool choice has been enabled please note that while' - " the parallel_tool_calls client option is preset for " - "compatibility reasons, it will be ignored." - ) - - self.reasoning_parser: Optional[Callable[[AnyTokenizer], ReasoningParser]] = ( - None + self.tool_parser = self._get_tool_parser( + tool_parser_name=tool_parser, enable_auto_tools=enable_auto_tools ) - if reasoning_parser: - try: - self.reasoning_parser = ReasoningParserManager.get_reasoning_parser( - reasoning_parser - ) - assert self.reasoning_parser is not None - except Exception as e: - raise TypeError(f"{reasoning_parser=} has not been registered") from e - self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None - if self.enable_auto_tools: - try: - if tool_parser == "pythonic" and model_config.model.startswith( - "meta-llama/Llama-3.2" - ): - logger.warning( - "Llama3.2 models may struggle to emit valid pythonic tool calls" - ) - self.tool_parser = ToolParserManager.get_tool_parser(tool_parser) - except Exception as e: - raise TypeError( - "Error: --enable-auto-tool-choice requires " - f"tool_parser:'{tool_parser}' which has not " - "been registered" - ) from e self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none self.enable_prompt_tokens_details = enable_prompt_tokens_details @@ -169,7 +137,7 @@ def __init__( else: self.tool_call_id_type = "random" - self.use_harmony = model_config.hf_config.model_type == "gpt_oss" + self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: if "stop_token_ids" not in self.default_sampling_params: self.default_sampling_params["stop_token_ids"] = [] @@ -190,8 +158,8 @@ def __init__( async def create_chat_completion( self, request: ChatCompletionRequest, - raw_request: Optional[Request] = None, - ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, ErrorResponse]: + raw_request: Request | None = None, + ) -> AsyncGenerator[str, None] | 
ChatCompletionResponse | ErrorResponse: """ Chat Completion API similar to OpenAI's API. @@ -312,7 +280,7 @@ async def create_chat_completion( default_sampling_params=self.default_sampling_params, ) - sampling_params: Union[SamplingParams, BeamSearchParams] + sampling_params: SamplingParams | BeamSearchParams if request.use_beam_search: sampling_params = request.to_beam_search_params( max_tokens, self.default_sampling_params @@ -338,7 +306,7 @@ async def create_chat_completion( ) if isinstance(sampling_params, BeamSearchParams): - generator = self.engine_client.beam_search( + generator = self.beam_search( prompt=engine_prompt, request_id=request_id, params=sampling_params, @@ -383,7 +351,6 @@ async def create_chat_completion( conversation, tokenizer, request_metadata, - enable_force_include_usage=self.enable_force_include_usage, ) try: @@ -447,11 +414,11 @@ def _filter_delta_text(delta_text: str, previous_text: str) -> tuple[str, bool]: def extract_tool_call_required_streaming( self, previous_text: str, - current_text: Optional[str], + current_text: str | None, delta_text: str, function_name_returned: bool, - tool_call_idx: Optional[int] = None, - ) -> tuple[Optional[DeltaMessage], bool]: + tool_call_idx: int | None = None, + ) -> tuple[DeltaMessage | None, bool]: if current_text is None or current_text == "": # if the current text is empty, we cannot parse it return None, function_name_returned @@ -549,7 +516,6 @@ async def chat_completion_stream_generator( conversation: list[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - enable_force_include_usage: bool, ) -> AsyncGenerator[str, None]: created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" @@ -579,7 +545,7 @@ async def chat_completion_stream_generator( and self._should_stream_with_auto_tool_parsing(request) ) - all_previous_token_ids: Optional[list[list[int]]] + all_previous_token_ids: list[list[int]] | None function_name_returned = [False] * num_choices if self.tool_call_id_type == "kimi_k2": history_tool_call_cnt = get_history_tool_calls_cnt(conversation) @@ -597,14 +563,15 @@ async def chat_completion_stream_generator( # For reasoning parser and tool call all enabled added_content_delta_arr = [False] * num_choices reasoning_end_arr = [False] * num_choices - elif request.tool_choice == "required": - all_previous_token_ids = None else: all_previous_token_ids = None try: if self.reasoning_parser: - reasoning_parser = self.reasoning_parser(tokenizer) + reasoning_parser = self.reasoning_parser( + tokenizer, + chat_template_kwargs=request.chat_template_kwargs, # type: ignore + ) except RuntimeError as e: logger.exception("Error in reasoning parser creation.") data = self.create_streaming_error_response(str(e)) @@ -614,7 +581,7 @@ async def chat_completion_stream_generator( # Prepare the tool parser if it's needed try: if tool_choice_auto and self.tool_parser: - tool_parsers: list[Optional[ToolParser]] = [ + tool_parsers: list[ToolParser | None] = [ self.tool_parser(tokenizer) ] * num_choices else: @@ -627,13 +594,9 @@ async def chat_completion_stream_generator( return stream_options = request.stream_options - if stream_options: - include_usage = stream_options.include_usage or enable_force_include_usage - include_continuous_usage = ( - include_usage and stream_options.continuous_usage_stats - ) - else: - include_usage, include_continuous_usage = False, False + include_usage, include_continuous_usage = should_include_usage( + stream_options, 
self.enable_force_include_usage + ) try: async for res in result_generator: @@ -692,7 +655,7 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message if request.echo: - last_msg_content: Union[str, list[dict[str, str]]] = "" + last_msg_content: str | list[dict[str, str]] = "" if ( conversation and "content" in conversation[-1] @@ -765,7 +728,7 @@ async def chat_completion_stream_generator( # Chunked prefill case, don't return empty chunks continue - delta_message: Optional[DeltaMessage] + delta_message: DeltaMessage | None # just update previous_texts and previous_token_ids if tool_choice_auto or self.reasoning_parser: @@ -915,29 +878,56 @@ async def chat_completion_stream_generator( previous_text = previous_texts[i] current_text = previous_text + delta_text fn_name_returned = function_name_returned[i] + output_token_ids = as_list(output.token_ids) + + if ( + self.reasoning_parser is not None + and not reasoning_end_arr[i] + and res.prompt_token_ids + and reasoning_parser.is_reasoning_end(res.prompt_token_ids) + ): + reasoning_end_arr[i] = True - if self.reasoning_parser: - _, content = reasoning_parser.extract_reasoning_content( - current_text, request + if self.reasoning_parser and not reasoning_end_arr[i]: + delta_message = ( + reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output_token_ids, + ) ) + if reasoning_parser.is_reasoning_end(output_token_ids): + reasoning_end_arr[i] = True + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + # reasoning ended + current_text = "" + else: + # either finished reasoning or no reasoning at all content = current_text - delta_message, function_name_returned[i] = ( - self.extract_tool_call_required_streaming( - previous_text=previous_text, - current_text=content, - delta_text=delta_text, - function_name_returned=fn_name_returned, - tool_call_idx=history_tool_call_cnt, + + delta_message, function_name_returned[i] = ( + self.extract_tool_call_required_streaming( + previous_text=previous_text, + current_text=content, + delta_text=delta_text, + function_name_returned=fn_name_returned, + tool_call_idx=history_tool_call_cnt, + ) ) - ) - if ( - delta_message - and delta_message.tool_calls - and delta_message.tool_calls[0].id is not None - ): - history_tool_call_cnt += 1 - tools_streamed[i] = True + if ( + delta_message + and delta_message.tool_calls + and delta_message.tool_calls[0].id is not None + ): + history_tool_call_cnt += 1 + tools_streamed[i] = True # handle streaming deltas for tools with "auto" tool choice # and reasoning parser @@ -1291,9 +1281,9 @@ async def chat_completion_full_generator( conversation: list[ConversationMessage], tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - ) -> Union[ErrorResponse, ChatCompletionResponse]: + ) -> ErrorResponse | ChatCompletionResponse: created_time = int(time.time()) - final_res: Optional[RequestOutput] = None + final_res: RequestOutput | None = None try: async for res in result_generator: @@ -1373,7 +1363,10 @@ async def chat_completion_full_generator( if self.reasoning_parser: try: - reasoning_parser = self.reasoning_parser(tokenizer) + reasoning_parser = self.reasoning_parser( + tokenizer, + chat_template_kwargs=request.chat_template_kwargs, # type: ignore + ) except RuntimeError as e: logger.exception("Error in reasoning parser creation.") return 
self.create_error_response(str(e)) @@ -1543,7 +1536,7 @@ async def chat_completion_full_generator( choices.append(choice_data) if request.echo: - last_msg_content: Union[str, list[dict[str, str]]] = "" + last_msg_content: str | list[dict[str, str]] = "" if ( conversation and "content" in conversation[-1] @@ -1628,7 +1621,7 @@ async def chat_completion_full_generator( def _get_top_logprobs( self, logprobs: dict[int, Logprob], - top_logprobs: Optional[int], + top_logprobs: int | None, tokenizer: AnyTokenizer, should_return_as_token_id: bool, ) -> list[ChatCompletionLogProb]: @@ -1646,16 +1639,16 @@ def _get_top_logprobs( bytes=list(token.encode("utf-8", errors="replace")), ) for i, p in enumerate(logprobs.items()) - if top_logprobs and i < top_logprobs + if (top_logprobs and i < top_logprobs or top_logprobs == -1) ] def _create_chat_logprobs( self, token_ids: GenericSequence[int], - top_logprobs: GenericSequence[Optional[dict[int, Logprob]]], + top_logprobs: GenericSequence[dict[int, Logprob] | None], tokenizer: AnyTokenizer, - num_output_top_logprobs: Optional[int] = None, - return_as_token_id: Optional[bool] = None, + num_output_top_logprobs: int | None = None, + return_as_token_id: bool | None = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" logprobs_content: list[ChatCompletionLogProbsContent] = [] @@ -1724,7 +1717,7 @@ def _should_stream_with_auto_tool_parsing(self, request: ChatCompletionRequest): def _should_check_for_unstreamed_tool_arg_tokens( self, - delta_message: Optional[DeltaMessage], + delta_message: DeltaMessage | None, output: CompletionOutput, ) -> bool: """ diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 25e167e9bb0c..45bbe732a680 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -2,13 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from http import HTTPStatus -from typing import Optional, Union, cast +from typing import cast import numpy as np from fastapi import Request from typing_extensions import override -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( @@ -37,7 +36,7 @@ class ClassificationMixin(OpenAIServing): async def _preprocess( self, ctx: ServeContext, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: """ Process classification inputs: tokenize text, resolve adapters, and prepare model-specific inputs. @@ -71,7 +70,7 @@ async def _preprocess( def _build_response( self, ctx: ServeContext, - ) -> Union[ClassificationResponse, ErrorResponse]: + ) -> ClassificationResponse | ErrorResponse: """ Convert model outputs to a formatted classification response with probabilities and labels. 
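
The serving_chat.py hunk earlier in this patch, and the serving_completion.py hunk that follows, replace the inline stream_options handling with a shared should_include_usage helper imported from vllm.entrypoints.utils. The helper itself is not shown in this diff; the sketch below is only a reconstruction from the inline logic being removed, so the real implementation may differ in detail.

    # Hypothetical reconstruction of vllm.entrypoints.utils.should_include_usage,
    # based on the inline stream_options logic removed from serving_chat.py and
    # serving_completion.py in this patch.
    def should_include_usage(
        stream_options, enable_force_include_usage: bool
    ) -> tuple[bool, bool]:
        """Return (include_usage, include_continuous_usage) for a streaming request."""
        if stream_options:
            include_usage = bool(
                stream_options.include_usage or enable_force_include_usage
            )
            include_continuous_usage = bool(
                include_usage and stream_options.continuous_usage_stats
            )
        else:
            include_usage, include_continuous_usage = False, False
        return include_usage, include_continuous_usage

Centralizing this keeps the new --enable-force-include-usage flag behaving the same way across the chat and completion streaming paths.
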
@@ -128,15 +127,13 @@ class ServingClassification(ClassificationMixin): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], + request_logger: RequestLogger | None, log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, log_error_stack=log_error_stack, @@ -146,7 +143,7 @@ async def create_classify( self, request: ClassificationRequest, raw_request: Request, - ) -> Union[ClassificationResponse, ErrorResponse]: + ) -> ClassificationResponse | ErrorResponse: model_name = self.models.model_name() request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" @@ -163,7 +160,7 @@ async def create_classify( def _create_pooling_params( self, ctx: ClassificationServeContext, - ) -> Union[PoolingParams, ErrorResponse]: + ) -> PoolingParams | ErrorResponse: pooling_params = super()._create_pooling_params(ctx) if isinstance(pooling_params, ErrorResponse): return pooling_params diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index ce0a6c0e23e5..44211201d49a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -5,12 +5,11 @@ import time from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import Sequence as GenericSequence -from typing import Optional, Union, cast +from typing import cast import jinja2 from fastapi import Request -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( @@ -28,14 +27,15 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig -from vllm.entrypoints.utils import get_max_tokens +from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import as_list, merge_async_iterators +from vllm.utils.async_utils import merge_async_iterators +from vllm.utils.collection_utils import as_list logger = init_logger(__name__) @@ -44,10 +44,9 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], + request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, @@ -55,15 +54,14 @@ def __init__( ): super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage, log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.default_sampling_params = self.model_config.get_diff_sampling_param() + self.enable_force_include_usage = enable_force_include_usage if 
self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source @@ -76,8 +74,8 @@ def __init__( async def create_completion( self, request: CompletionRequest, - raw_request: Optional[Request] = None, - ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]: + raw_request: Request | None = None, + ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/completions/create @@ -169,7 +167,7 @@ async def create_completion( default_sampling_params=self.default_sampling_params, ) - sampling_params: Union[SamplingParams, BeamSearchParams] + sampling_params: SamplingParams | BeamSearchParams if request.use_beam_search: sampling_params = request.to_beam_search_params( max_tokens, self.default_sampling_params @@ -199,9 +197,9 @@ async def create_completion( # Mypy inconsistently requires this second cast in different # environments. It shouldn't be necessary (redundant from above) # but pre-commit in CI fails without it. - engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], engine_prompt) + engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt) if isinstance(sampling_params, BeamSearchParams): - generator = self.engine_client.beam_search( + generator = self.beam_search( prompt=engine_prompt, request_id=request_id, params=sampling_params, @@ -259,11 +257,10 @@ async def create_completion( num_prompts=num_prompts, tokenizer=tokenizer, request_metadata=request_metadata, - enable_force_include_usage=self.enable_force_include_usage, ) # Non-streaming response - final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts + final_res_batch: list[RequestOutput | None] = [None] * num_prompts try: async for i, res in result_generator: final_res_batch[i] = res @@ -315,7 +312,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: async def completion_stream_generator( self, request: CompletionRequest, - engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]], + engine_prompts: list[TokensPrompt | EmbedsPrompt], result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -323,7 +320,6 @@ async def completion_stream_generator( num_prompts: int, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - enable_force_include_usage: bool, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n previous_text_lens = [0] * num_choices * num_prompts @@ -334,13 +330,9 @@ async def completion_stream_generator( first_iteration = True stream_options = request.stream_options - if stream_options: - include_usage = stream_options.include_usage or enable_force_include_usage - include_continuous_usage = ( - include_usage and stream_options.continuous_usage_stats - ) - else: - include_usage, include_continuous_usage = False, False + include_usage, include_continuous_usage = should_include_usage( + stream_options, self.enable_force_include_usage + ) try: async for prompt_idx, res in result_generator: @@ -365,7 +357,7 @@ async def completion_stream_generator( num_prompt_tokens[prompt_idx] = len(prompt_token_ids) delta_token_ids: GenericSequence[int] - out_logprobs: Optional[GenericSequence[Optional[dict[int, Logprob]]]] + out_logprobs: GenericSequence[dict[int, Logprob] | None] | None for output in res.outputs: i = output.index + prompt_idx * num_choices @@ -373,7 +365,7 @@ async def completion_stream_generator( # Useful when 
request.return_token_ids is True # Returning prompt token IDs shares the same logic # with the echo implementation. - prompt_token_ids_to_return: Optional[list[int]] = None + prompt_token_ids_to_return: list[int] | None = None assert request.max_tokens is not None if request.echo and not has_echoed[i]: @@ -527,7 +519,7 @@ def request_output_to_completion_response( prompt_text = final_res.prompt token_ids: GenericSequence[int] - out_logprobs: Optional[GenericSequence[Optional[dict[int, Logprob]]]] + out_logprobs: GenericSequence[dict[int, Logprob] | None] | None for output in final_res.outputs: assert request.max_tokens is not None @@ -620,17 +612,17 @@ def request_output_to_completion_response( def _create_completion_logprobs( self, token_ids: GenericSequence[int], - top_logprobs: GenericSequence[Optional[dict[int, Logprob]]], + top_logprobs: GenericSequence[dict[int, Logprob] | None], num_output_top_logprobs: int, tokenizer: AnyTokenizer, initial_text_offset: int = 0, - return_as_token_id: Optional[bool] = None, + return_as_token_id: bool | None = None, ) -> CompletionLogProbs: """Create logprobs for OpenAI Completion API.""" out_text_offset: list[int] = [] - out_token_logprobs: list[Optional[float]] = [] + out_token_logprobs: list[float | None] = [] out_tokens: list[str] = [] - out_top_logprobs: list[Optional[dict[str, float]]] = [] + out_top_logprobs: list[dict[str, float] | None] = [] last_token_len = 0 @@ -698,7 +690,7 @@ def _create_completion_logprobs( def _build_render_config( self, request: CompletionRequest, - max_input_length: Optional[int] = None, + max_input_length: int | None = None, ) -> RenderConfig: max_input_tokens_len = self.max_model_len - (request.max_tokens or 0) return RenderConfig( diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 5517ab2802e3..51f6106acec3 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,20 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import base64 +import json from collections.abc import AsyncGenerator, Mapping -from typing import Any, Final, Literal, Optional, Union, cast +from typing import Any, Final, cast -import numpy as np import torch from fastapi import Request +from fastapi.responses import Response from typing_extensions import assert_never, override -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( + EmbeddingBytesResponse, EmbeddingChatRequest, EmbeddingCompletionRequest, EmbeddingRequest, @@ -34,33 +33,25 @@ from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.outputs import ( - EmbeddingOutput, EmbeddingRequestOutput, PoolingOutput, PoolingRequestOutput, RequestOutput, ) from vllm.pooling_params import PoolingParams -from vllm.utils import chunk_list +from vllm.utils.async_utils import merge_async_iterators +from vllm.utils.collection_utils import chunk_list +from vllm.utils.serial_utils import ( + EmbedDType, + EncodingFormat, + Endianness, + encode_pooling_bytes, + encode_pooling_output, +) logger = init_logger(__name__) -def _get_embedding( - output: EmbeddingOutput, - encoding_format: Literal["float", "base64"], -) -> Union[list[float], str]: - if encoding_format == "float": - 
return output.embedding - elif encoding_format == "base64": - # Force to use float32 for base64 encoding - # to match the OpenAI python client behavior - embedding_bytes = np.array(output.embedding, dtype="float32").tobytes() - return base64.b64encode(embedding_bytes).decode("utf-8") - - assert_never(encoding_format) - - class EmbeddingMixin(OpenAIServing): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -81,7 +72,7 @@ def __init__(self, *args, **kwargs): async def _preprocess( self, ctx: ServeContext, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: ctx = cast(EmbeddingServeContext, ctx) try: ctx.lora_request = self._maybe_get_adapters(ctx.request) @@ -131,38 +122,70 @@ def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderCon def _build_response( self, ctx: ServeContext, - ) -> Union[EmbeddingResponse, ErrorResponse]: - items: list[EmbeddingResponseData] = [] - num_prompt_tokens = 0 - + ) -> EmbeddingResponse | Response | ErrorResponse: final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) - for idx, final_res in enumerate(final_res_batch_checked): - embedding_res = EmbeddingRequestOutput.from_base(final_res) + encoding_format: EncodingFormat = ctx.request.encoding_format + embed_dtype: EmbedDType = ctx.request.embed_dtype + endianness: Endianness = ctx.request.endianness + + def encode_float_base64(): + items: list[EmbeddingResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch_checked): + item = EmbeddingResponseData( + index=idx, + embedding=encode_pooling_output( + final_res, + encoding_format=encoding_format, + embed_dtype=embed_dtype, + endianness=endianness, + ), + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) - item = EmbeddingResponseData( - index=idx, - embedding=_get_embedding( - embedding_res.outputs, ctx.request.encoding_format - ), + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, ) - prompt_token_ids = final_res.prompt_token_ids - items.append(item) - num_prompt_tokens += len(prompt_token_ids) + return EmbeddingResponse( + id=ctx.request_id, + created=ctx.created_time, + model=ctx.model_name, + data=items, + usage=usage, + ) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - total_tokens=num_prompt_tokens, - ) + def encode_bytes(): + body, items, usage = encode_pooling_bytes( + pooling_outputs=final_res_batch_checked, + embed_dtype=embed_dtype, + endianness=endianness, + ) - return EmbeddingResponse( - id=ctx.request_id, - created=ctx.created_time, - model=ctx.model_name, - data=items, - usage=usage, - ) + metadata = { + "id": ctx.request_id, + "created": ctx.created_time, + "model": ctx.model_name, + "data": items, + "usage": usage, + } + return EmbeddingBytesResponse( + body=body, + metadata=json.dumps(metadata), + ) + + if encoding_format == "float" or encoding_format == "base64": + return encode_float_base64() + elif encoding_format == "bytes": + return encode_bytes() + else: + assert_never(encoding_format) def _get_max_position_embeddings(self) -> int: """Get the model's effective maximum sequence length for chunking.""" @@ -315,9 +338,9 @@ async def _create_single_prompt_generator( ctx: EmbeddingServeContext, engine_prompt: EngineTokensPrompt, pooling_params: PoolingParams, - trace_headers: Optional[Mapping[str, str]], + trace_headers: Mapping[str, str] | None, prompt_index: int, - ) -> AsyncGenerator[Union[RequestOutput, 
PoolingRequestOutput], None]: + ) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]: """Create a generator for a single prompt using standard processing.""" request_id_item = f"{ctx.request_id}-{prompt_index}" @@ -342,7 +365,7 @@ async def _create_single_prompt_generator( async def _prepare_generators( self, ctx: ServeContext, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: """Override to support chunked processing.""" ctx = cast(EmbeddingServeContext, ctx) @@ -355,7 +378,7 @@ async def _prepare_generators( # Custom logic for chunked processing generators: list[ - AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None] + AsyncGenerator[RequestOutput | PoolingRequestOutput, None] ] = [] try: @@ -400,8 +423,6 @@ async def _prepare_generators( ) generators.append(generator) - from vllm.utils import merge_async_iterators - ctx.result_generator = merge_async_iterators(*generators) return None @@ -414,7 +435,7 @@ async def _prepare_generators( async def _collect_batch( self, ctx: ServeContext, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: """Collect and aggregate batch results with support for chunked processing. @@ -523,9 +544,7 @@ async def _collect_batch( ) # Finalize aggregated results - final_res_batch: list[ - Union[PoolingRequestOutput, EmbeddingRequestOutput] - ] = [] + final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = [] num_prompts = len(ctx.engine_prompts) for prompt_idx in range(num_prompts): @@ -564,6 +583,7 @@ async def _collect_batch( request_id=aggregator["request_id"], prompt_token_ids=original_token_ids, outputs=pooling_output_data, + num_cached_tokens=0, finished=True, ) @@ -582,7 +602,7 @@ async def _collect_batch( ) ctx.final_res_batch = cast( - list[Union[RequestOutput, PoolingRequestOutput]], final_res_batch + list[RequestOutput | PoolingRequestOutput], final_res_batch ) return None @@ -597,18 +617,16 @@ class OpenAIServingEmbedding(EmbeddingMixin): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], - chat_template: Optional[str], + request_logger: RequestLogger | None, + chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, trust_request_chat_template: bool = False, log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, log_error_stack=log_error_stack, @@ -621,8 +639,8 @@ def __init__( async def create_embedding( self, request: EmbeddingRequest, - raw_request: Optional[Request] = None, - ) -> Union[EmbeddingResponse, ErrorResponse]: + raw_request: Request | None = None, + ) -> EmbeddingResponse | ErrorResponse: """ Embedding API similar to OpenAI's API. 
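
For context on the serving_embedding.py changes above: the removed _get_embedding helper implemented the OpenAI-compatible "base64" format by packing the embedding as float32 and base64-encoding the bytes, while the new encode_pooling_output / encode_pooling_bytes helpers generalize this to a configurable embed_dtype and endianness plus a raw "bytes" response. A minimal client-side sketch of decoding the base64 form, assuming the historical float32 default; the decode_base64_embedding name is illustrative and not part of vLLM.

    import base64

    import numpy as np


    def decode_base64_embedding(b64: str) -> list[float]:
        # Assumes the server packed the embedding as float32, which is what the
        # removed _get_embedding helper did to match the OpenAI python client.
        raw = base64.b64decode(b64)
        return np.frombuffer(raw, dtype=np.float32).tolist()


    # Usage sketch against a /v1/embeddings response requested with
    # encoding_format="base64":
    # vector = decode_base64_embedding(response["data"][0]["embedding"])
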
@@ -650,7 +668,7 @@ async def create_embedding( def _create_pooling_params( self, ctx: ServeContext[EmbeddingRequest], - ) -> Union[PoolingParams, ErrorResponse]: + ) -> PoolingParams | ErrorResponse: pooling_params = super()._create_pooling_params(ctx) if isinstance(pooling_params, ErrorResponse): return pooling_params @@ -665,7 +683,7 @@ def _create_pooling_params( async def _preprocess( self, ctx: ServeContext, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: if isinstance(ctx.request, EmbeddingChatRequest): error_check_ret = self._validate_chat_template( request_chat_template=ctx.request.chat_template, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 6ddde23b4a34..af5a423134fb 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio import json import sys import time import traceback -from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence +from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus -from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union +from typing import Any, ClassVar, Generic, TypeAlias, TypeVar import torch from fastapi import Request @@ -15,17 +16,13 @@ from starlette.datastructures import Headers from typing_extensions import TypeIs -from vllm.entrypoints.utils import _validate_truncation_size -from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.processor import Processor - if sys.version_info >= (3, 12): from typing import TypedDict else: from typing_extensions import TypedDict import vllm.envs as envs -from vllm.config import ModelConfig +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, @@ -66,11 +63,16 @@ TranslationRequest, ) from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.openai.tool_parsers import ToolParser +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig +from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import PromptType from vllm.inputs.data import TokensPrompt as EngineTokensPrompt -from vllm.inputs.parse import PromptComponents, get_prompt_components +from vllm.inputs.parse import ( + PromptComponents, + get_prompt_components, + is_explicit_encoder_decoder_prompt, +) from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest @@ -78,8 +80,9 @@ MultiModalDataDict, MultiModalUUIDDict, ) -from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tracing import ( contains_trace_headers, @@ -87,48 +90,50 @@ log_tracing_disabled_warning, ) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import ( +from vllm.utils import random_uuid +from 
vllm.utils.async_utils import ( AsyncMicrobatchTokenizer, - is_list_of, + collect_from_async_generator, make_async, merge_async_iterators, - random_uuid, ) +from vllm.utils.collection_utils import is_list_of +from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) -CompletionLikeRequest = Union[ - CompletionRequest, - DetokenizeRequest, - EmbeddingCompletionRequest, - RerankRequest, - ClassificationRequest, - ScoreRequest, - TokenizeCompletionRequest, -] - -ChatLikeRequest = Union[ - ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest -] -SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] -AnyRequest = Union[ - CompletionLikeRequest, - ChatLikeRequest, - SpeechToTextRequest, - ResponsesRequest, - IOProcessorRequest, -] +CompletionLikeRequest: TypeAlias = ( + CompletionRequest + | DetokenizeRequest + | EmbeddingCompletionRequest + | RerankRequest + | ClassificationRequest + | ScoreRequest + | TokenizeCompletionRequest +) -AnyResponse = Union[ - CompletionResponse, - ChatCompletionResponse, - EmbeddingResponse, - TranscriptionResponse, - TokenizeResponse, - PoolingResponse, - ClassificationResponse, - ScoreResponse, -] +ChatLikeRequest: TypeAlias = ( + ChatCompletionRequest | EmbeddingChatRequest | TokenizeChatRequest +) +SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest +AnyRequest: TypeAlias = ( + CompletionLikeRequest + | ChatLikeRequest + | SpeechToTextRequest + | ResponsesRequest + | IOProcessorRequest +) + +AnyResponse: TypeAlias = ( + CompletionResponse + | ChatCompletionResponse + | EmbeddingResponse + | TranscriptionResponse + | TokenizeResponse + | PoolingResponse + | ClassificationResponse + | ScoreResponse +) class TextTokensPrompt(TypedDict): @@ -140,7 +145,7 @@ class EmbedsPrompt(TypedDict): prompt_embeds: torch.Tensor -RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt] +RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: @@ -168,8 +173,8 @@ class RequestProcessingMixin(BaseModel): handling prompt preparation and engine input. """ - request_prompts: Optional[Sequence[RequestPrompt]] = [] - engine_prompts: Optional[list[EngineTokensPrompt]] = [] + request_prompts: Sequence[RequestPrompt] | None = [] + engine_prompts: list[EngineTokensPrompt] | None = [] model_config = ConfigDict(arbitrary_types_allowed=True) @@ -180,10 +185,10 @@ class ResponseGenerationMixin(BaseModel): managing result generators and final batch results. 
""" - result_generator: Optional[ - AsyncGenerator[tuple[int, Union[RequestOutput, PoolingRequestOutput]], None] - ] = None - final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( + result_generator: ( + AsyncGenerator[tuple[int, RequestOutput | PoolingRequestOutput], None] | None + ) = None + final_res_batch: list[RequestOutput | PoolingRequestOutput] = Field( default_factory=list ) @@ -198,14 +203,14 @@ class ServeContext( ): # Shared across all requests request: RequestT - raw_request: Optional[Request] = None + raw_request: Request | None = None model_name: str request_id: str created_time: int = Field(default_factory=lambda: int(time.time())) - lora_request: Optional[LoRARequest] = None + lora_request: LoRARequest | None = None # Shared across most requests - tokenizer: Optional[AnyTokenizer] = None + tokenizer: AnyTokenizer | None = None # `protected_namespaces` resolves Pydantic v2's warning # on conflict with protected namespace "model_" @@ -219,7 +224,7 @@ class ServeContext( class EmbeddingServeContext(ServeContext[EmbeddingRequest]): - chat_template: Optional[str] = None + chat_template: str | None = None chat_template_content_format: ChatTemplateContentFormatOption @@ -240,26 +245,20 @@ class OpenAIServing: def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], + request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, - enable_force_include_usage: bool = False, log_error_stack: bool = False, ): super().__init__() self.engine_client = engine_client - self.model_config = model_config - self.max_model_len = model_config.max_model_len self.models = models self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids - self.enable_force_include_usage = enable_force_include_usage - self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) self._apply_mistral_chat_template_async = make_async( apply_mistral_chat_template, executor=self._tokenizer_executor @@ -268,14 +267,244 @@ def __init__( self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} self.log_error_stack = log_error_stack - async def _get_processor(self) -> Processor: - if not hasattr(self, "_processor"): - vllm_config = await self.engine_client.get_vllm_config() - self._processor = Processor(vllm_config) + self.processor = self.models.processor + self.io_processor = self.models.io_processor + self.model_config = self.models.model_config + self.max_model_len = self.model_config.max_model_len + + def _get_tool_parser( + self, tool_parser_name: str | None = None, enable_auto_tools: bool = False + ) -> Callable[[AnyTokenizer], ToolParser] | None: + """Get the tool parser based on the name.""" + parser = None + if not enable_auto_tools or tool_parser_name is None: + return parser + logger.info( + '"auto" tool choice has been enabled please note that while' + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored." 
+ ) + + try: + if tool_parser_name == "pythonic" and self.model_config.model.startswith( + "meta-llama/Llama-3.2" + ): + logger.warning( + "Llama3.2 models may struggle to emit valid pythonic tool calls" + ) + parser = ToolParserManager.get_tool_parser(tool_parser_name) + except Exception as e: + raise TypeError( + "Error: --enable-auto-tool-choice requires " + f"tool_parser:'{tool_parser_name}' which has not " + "been registered" + ) from e + return parser + + def _get_reasoning_parser( + self, + reasoning_parser_name: str, + ) -> Callable[[AnyTokenizer], ReasoningParser] | None: + """Get the reasoning parser based on the name.""" + parser = None + if not reasoning_parser_name: + return None + try: + parser = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name) + assert parser is not None + except Exception as e: + raise TypeError(f"{reasoning_parser_name=} has not been registered") from e + return parser + + async def reset_mm_cache(self) -> None: + self.processor.clear_mm_cache() + await self.engine_client.reset_mm_cache() + + async def beam_search( + self, + prompt: PromptType, + request_id: str, + params: BeamSearchParams, + lora_request: LoRARequest | None = None, + ) -> AsyncGenerator[RequestOutput, None]: + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output + + processor = self.processor + tokenizer = processor.tokenizer + if tokenizer is None: + raise ValueError( + "You cannot use beam search when `skip_tokenizer_init` is True" + ) + + eos_token_id: int = tokenizer.eos_token_id # type: ignore + + if is_explicit_encoder_decoder_prompt(prompt): + raise NotImplementedError + else: + processed_inputs = processor.input_preprocessor._prompt_to_llm_inputs( + prompt + ) - return self._processor + if processed_inputs["type"] == "embeds": + raise NotImplementedError + + # This is a workaround to fix multimodal beam search; this is a + # bandaid fix for 2 small problems: + # 1. Multi_modal_data on the processed_inputs currently resolves to + # `None`. + # 2. preprocessing above expands the multimodal placeholders. However, + # this happens again in generation, so the double expansion causes + # a mismatch. + # TODO - would be ideal to handle this more gracefully. 
+ prompt_text: str | None + prompt_token_ids: list[int] + multi_modal_data: MultiModalDataDict | None + if isinstance(prompt, str): + prompt_text = prompt + prompt_token_ids = [] + multi_modal_data = None + else: + prompt_text = prompt.get("prompt") # type: ignore + prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore + multi_modal_data = prompt.get("multi_modal_data") # type: ignore + + mm_processor_kwargs: dict[str, Any] | None = processed_inputs.get( + "mm_processor_kwargs" + ) # type: ignore + + tokenized_length = len(prompt_token_ids) + + sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) + + beam_search_params = SamplingParams( + logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature, + ) + all_beams = [ + BeamSearchSequence( + tokens=prompt_token_ids, + cum_logprob=0, + logprobs=[], + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + lora_request=lora_request, + ) + ] + completed = [] + + for _ in range(max_tokens): + prompts_batch, lora_req_batch = zip( + *[ + ( + EngineTokensPrompt( + prompt_token_ids=beam.tokens, + multi_modal_data=beam.multi_modal_data, + mm_processor_kwargs=beam.mm_processor_kwargs, + ), + beam.lora_request, + ) + for beam in all_beams + ] + ) + + tasks = [] + request_id_batch = f"{request_id}-{random_uuid()}" + + for i, (individual_prompt, lora_req) in enumerate( + zip(prompts_batch, lora_req_batch) + ): + request_id_item = f"{request_id_batch}-beam-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.engine_client.generate( + individual_prompt, + beam_search_params, + request_id_item, + lora_request=lora_req, + ) + ) + ) + tasks.append(task) + + output = [x[0] for x in await asyncio.gather(*tasks)] + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + if token_id == eos_token_id and not ignore_eos: + completed.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id] + if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + finish_reason="stop", + stop_reason=eos_token_id, + ) + ) + else: + new_beams.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam.mm_processor_kwargs, + ) + ) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + if beam.tokens[-1] == eos_token_id and not ignore_eos: + # Skip the eos token in the text. 
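A self-contained toy sketch (fake logprobs, no engine, and without the length-penalty sort key used by the real code) of the per-step expansion performed in the loop above: each surviving beam proposes candidate tokens, candidates are ranked by cumulative logprob, and only the top beam_width sequences continue.

import math

def expand_beams(beams, step_logprobs, beam_width):
    # beams: list of (tokens, cum_logprob); step_logprobs[i] maps token_id -> logprob
    # for beams[i] (the method above requests 2 * beam_width such candidates).
    candidates = []
    for (tokens, cum_lp), logprobs in zip(beams, step_logprobs):
        for token_id, lp in logprobs.items():
            candidates.append((tokens + [token_id], cum_lp + lp))
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:beam_width]

beams = [([1], 0.0)]
proposals = [{7: math.log(0.6), 9: math.log(0.4)}]
print(expand_beams(beams, proposals, beam_width=2))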
+ tokens = beam.tokens[tokenized_length:-1] + else: + tokens = beam.tokens[tokenized_length:] + beam.text = tokenizer.decode(tokens) + + yield RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + text=beam.text, # type: ignore + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + finish_reason=beam.finish_reason + if beam.finish_reason is not None + else "length", + stop_reason=beam.stop_reason, + ) + for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, + ) - def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer: + def _get_renderer(self, tokenizer: AnyTokenizer | None) -> BaseRenderer: """ Get a Renderer instance with the provided tokenizer. Uses shared async tokenizer pool for efficiency. @@ -313,7 +542,7 @@ def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: async def _preprocess( self, ctx: ServeContext, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: """ Default preprocessing hook. Subclasses may override to prepare `ctx` (classification, embedding, etc.). @@ -323,7 +552,7 @@ async def _preprocess( def _build_response( self, ctx: ServeContext, - ) -> Union[AnyResponse, ErrorResponse]: + ) -> AnyResponse | ErrorResponse: """ Default response builder. Subclass may override this method to return the appropriate response object. @@ -333,8 +562,8 @@ def _build_response( async def handle( self, ctx: ServeContext, - ) -> Union[AnyResponse, ErrorResponse]: - generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None] + ) -> AnyResponse | ErrorResponse: + generation: AsyncGenerator[AnyResponse | ErrorResponse, None] generation = self._pipeline(ctx) async for response in generation: @@ -345,7 +574,7 @@ async def handle( async def _pipeline( self, ctx: ServeContext, - ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]: + ) -> AsyncGenerator[AnyResponse | ErrorResponse, None]: """Execute the request processing pipeline yielding responses.""" if error := await self._check_model(ctx.request): yield error @@ -366,7 +595,7 @@ async def _pipeline( yield self._build_response(ctx) - def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]: + def _validate_request(self, ctx: ServeContext) -> ErrorResponse | None: truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) if ( @@ -383,7 +612,7 @@ def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]: def _create_pooling_params( self, ctx: ServeContext, - ) -> Union[PoolingParams, ErrorResponse]: + ) -> PoolingParams | ErrorResponse: if not hasattr(ctx.request, "to_pooling_params"): return self.create_error_response( "Request type does not support pooling parameters" @@ -394,10 +623,10 @@ def _create_pooling_params( async def _prepare_generators( self, ctx: ServeContext, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: """Schedule the request and get the result generator.""" generators: list[ - AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None] + AsyncGenerator[RequestOutput | PoolingRequestOutput, None] ] = [] try: @@ -446,14 +675,14 @@ async def _prepare_generators( async def _collect_batch( self, ctx: ServeContext, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: """Collect batch results from the result generator.""" try: if ctx.engine_prompts is None: return self.create_error_response("Engine prompts not 
available") num_prompts = len(ctx.engine_prompts) - final_res_batch: list[Optional[Union[RequestOutput, PoolingRequestOutput]]] + final_res_batch: list[RequestOutput | PoolingRequestOutput | None] final_res_batch = [None] * num_prompts if ctx.result_generator is None: @@ -506,7 +735,7 @@ def create_streaming_error_response( async def _check_model( self, request: AnyRequest, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: error_response = None if self._is_model_supported(request.model): @@ -532,9 +761,7 @@ async def _check_model( status_code=HTTPStatus.NOT_FOUND, ) - def _get_active_default_mm_loras( - self, request: AnyRequest - ) -> Optional[LoRARequest]: + def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None: """Determine if there are any active default multimodal loras.""" # TODO: Currently this is only enabled for chat completions # to be better aligned with only being enabled for .generate @@ -561,7 +788,7 @@ def _maybe_get_adapters( self, request: AnyRequest, supports_default_mm_loras: bool = False, - ) -> Optional[LoRARequest]: + ) -> LoRARequest | None: if request.model in self.models.lora_requests: return self.models.lora_requests[request.model] @@ -645,7 +872,7 @@ async def _normalize_prompt_tokens_to_input( self, request: AnyRequest, prompt_ids: list[int], - tokenizer: Optional[AnyTokenizer], + tokenizer: AnyTokenizer | None, ) -> TextTokensPrompt: truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) @@ -740,7 +967,7 @@ async def _tokenize_prompt_input_async( self, request: AnyRequest, tokenizer: AnyTokenizer, - prompt_input: Union[str, list[int]], + prompt_input: str | list[int], add_special_tokens: bool = True, ) -> TextTokensPrompt: """ @@ -759,7 +986,7 @@ async def _tokenize_prompt_inputs_async( self, request: AnyRequest, tokenizer: AnyTokenizer, - prompt_inputs: Iterable[Union[str, list[int]]], + prompt_inputs: Iterable[str | list[int]], add_special_tokens: bool = True, ) -> AsyncGenerator[TextTokensPrompt, None]: """ @@ -782,10 +1009,10 @@ async def _tokenize_prompt_inputs_async( def _validate_chat_template( self, - request_chat_template: Optional[str], - chat_template_kwargs: Optional[dict[str, Any]], + request_chat_template: str | None, + chat_template_kwargs: dict[str, Any] | None, trust_request_chat_template: bool, - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: if not trust_request_chat_template and ( request_chat_template is not None or ( @@ -802,17 +1029,17 @@ def _validate_chat_template( async def _preprocess_chat( self, - request: Union[ChatLikeRequest, ResponsesRequest], + request: ChatLikeRequest | ResponsesRequest, tokenizer: AnyTokenizer, messages: list[ChatCompletionMessageParam], - chat_template: Optional[str], + chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, add_generation_prompt: bool = True, continue_final_message: bool = False, - tool_dicts: Optional[list[dict[str, Any]]] = None, - documents: Optional[list[dict[str, str]]] = None, - chat_template_kwargs: Optional[dict[str, Any]] = None, - tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + tool_dicts: list[dict[str, Any]] | None = None, + documents: list[dict[str, str]] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, + tool_parser: Callable[[AnyTokenizer], ToolParser] | None = None, add_special_tokens: bool = False, ) -> tuple[ list[ConversationMessage], @@ -844,7 +1071,7 @@ async def _preprocess_chat( ) 
_chat_template_kwargs.update(chat_template_kwargs or {}) - request_prompt: Union[str, list[int]] + request_prompt: str | list[int] if tokenizer is None: request_prompt = "placeholder" @@ -926,10 +1153,10 @@ async def _process_inputs( self, request_id: str, engine_prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams | PoolingParams, *, - lora_request: Optional[LoRARequest], - trace_headers: Optional[Mapping[str, str]], + lora_request: LoRARequest | None, + trace_headers: Mapping[str, str] | None, priority: int, ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for AsyncLLM.""" @@ -938,8 +1165,7 @@ async def _process_inputs( self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs ) - processor = await self._get_processor() - engine_request = processor.process_inputs( + engine_request = self.processor.process_inputs( request_id, engine_prompt, params, @@ -957,7 +1183,7 @@ async def _generate_with_builtin_tools( engine_prompt: EngineTokensPrompt, sampling_params: SamplingParams, context: ConversationContext, - lora_request: Optional[LoRARequest] = None, + lora_request: LoRARequest | None = None, priority: int = 0, **kwargs, ): @@ -1019,7 +1245,7 @@ async def _generate_with_builtin_tools( def _get_prompt_components( self, - prompt: Union[RequestPrompt, PromptType], + prompt: RequestPrompt | PromptType, ) -> PromptComponents: if isinstance(prompt, list): return PromptComponents(token_ids=prompt) @@ -1029,9 +1255,9 @@ def _get_prompt_components( def _log_inputs( self, request_id: str, - inputs: Union[RequestPrompt, PromptType], - params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], - lora_request: Optional[LoRARequest], + inputs: RequestPrompt | PromptType, + params: SamplingParams | PoolingParams | BeamSearchParams | None, + lora_request: LoRARequest | None, ) -> None: if self.request_logger is None: return @@ -1050,7 +1276,7 @@ def _log_inputs( async def _get_trace_headers( self, headers: Headers, - ) -> Optional[Mapping[str, str]]: + ) -> Mapping[str, str] | None: is_tracing_enabled = await self.engine_client.is_tracing_enabled() if is_tracing_enabled: @@ -1063,8 +1289,8 @@ async def _get_trace_headers( @staticmethod def _base_request_id( - raw_request: Optional[Request], default: Optional[str] = None - ) -> Optional[str]: + raw_request: Request | None, default: str | None = None + ) -> str | None: """Pulls the request id to use from a header, if provided""" default = default or random_uuid() if raw_request is None: @@ -1086,15 +1312,15 @@ def _get_decoded_token( return logprob.decoded_token return tokenizer.decode(token_id) - def _is_model_supported(self, model_name: Optional[str]) -> bool: + def _is_model_supported(self, model_name: str | None) -> bool: if not model_name: return True return self.models.is_base_model(model_name) def clamp_prompt_logprobs( - prompt_logprobs: Union[PromptLogprobs, None], -) -> Union[PromptLogprobs, None]: + prompt_logprobs: PromptLogprobs | None, +) -> PromptLogprobs | None: if prompt_logprobs is None: return prompt_logprobs diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index d2a58a487a76..9b7deb40b93f 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -5,9 +5,7 @@ from collections import defaultdict from dataclasses import dataclass from http import HTTPStatus -from typing import Optional, Union -from vllm.config import ModelConfig from 
vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import ( ErrorInfo, @@ -36,7 +34,7 @@ class BaseModelPath: class LoRAModulePath: name: str path: str - base_model_name: Optional[str] = None + base_model_name: str | None = None class OpenAIServingModels: @@ -51,18 +49,14 @@ class OpenAIServingModels: def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, base_model_paths: list[BaseModelPath], *, - lora_modules: Optional[list[LoRAModulePath]] = None, + lora_modules: list[LoRAModulePath] | None = None, ): super().__init__() - self.base_model_paths = base_model_paths - - self.max_model_len = model_config.max_model_len self.engine_client = engine_client - self.model_config = model_config + self.base_model_paths = base_model_paths self.static_lora_modules = lora_modules self.lora_requests: dict[str, LoRARequest] = {} @@ -75,6 +69,11 @@ def __init__( ) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) + self.processor = self.engine_client.processor + self.io_processor = self.engine_client.io_processor + self.model_config = self.engine_client.model_config + self.max_model_len = self.model_config.max_model_len + async def init_static_loras(self): """Loads all static LoRA modules. Raises if any fail to load""" @@ -93,7 +92,7 @@ async def init_static_loras(self): def is_base_model(self, model_name) -> bool: return any(model.name == model_name for model in self.base_model_paths) - def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: + def model_name(self, lora_request: LoRARequest | None = None) -> str: """Returns the appropriate model name depending on the availability and support of the LoRA or base model. Parameters: @@ -132,8 +131,8 @@ async def show_available_models(self) -> ModelList: return ModelList(data=model_cards) async def load_lora_adapter( - self, request: LoadLoRAAdapterRequest, base_model_name: Optional[str] = None - ) -> Union[ErrorResponse, str]: + self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None + ) -> ErrorResponse | str: lora_name = request.lora_name # Ensure atomicity based on the lora name @@ -173,7 +172,7 @@ async def load_lora_adapter( async def unload_lora_adapter( self, request: UnloadLoRAAdapterRequest - ) -> Union[ErrorResponse, str]: + ) -> ErrorResponse | str: lora_name = request.lora_name # Ensure atomicity based on the lora name @@ -189,7 +188,7 @@ async def unload_lora_adapter( async def _check_load_lora_adapter_request( self, request: LoadLoRAAdapterRequest - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: # Check if both 'lora_name' and 'lora_path' are provided if not request.lora_name or not request.lora_path: return create_error_response( @@ -211,7 +210,7 @@ async def _check_load_lora_adapter_request( async def _check_unload_lora_adapter_request( self, request: UnloadLoRAAdapterRequest - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: # Check if 'lora_name' is not provided return an error if not request.lora_name: return create_error_response( @@ -230,7 +229,7 @@ async def _check_unload_lora_adapter_request( return None - async def resolve_lora(self, lora_name: str) -> Union[LoRARequest, ErrorResponse]: + async def resolve_lora(self, lora_name: str) -> LoRARequest | ErrorResponse: """Attempt to resolve a LoRA adapter using available resolvers. 
Args: diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 390b388e303c..568896ccbf1b 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -2,18 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import base64 +import json import time from collections.abc import AsyncGenerator -from typing import Final, Literal, Optional, Union, cast +from typing import Final, cast import jinja2 -import numpy as np -import torch from fastapi import Request from typing_extensions import assert_never -from vllm.config import VllmConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger @@ -21,6 +18,7 @@ ErrorResponse, IOProcessorRequest, IOProcessorResponse, + PoolingBytesResponse, PoolingChatRequest, PoolingCompletionRequest, PoolingRequest, @@ -33,61 +31,50 @@ from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import _validate_truncation_size from vllm.logger import init_logger -from vllm.outputs import PoolingOutput, PoolingRequestOutput -from vllm.plugins.io_processors import get_io_processor -from vllm.utils import merge_async_iterators +from vllm.outputs import PoolingRequestOutput +from vllm.tasks import PoolingTask, SupportedTask +from vllm.utils.async_utils import merge_async_iterators +from vllm.utils.serial_utils import ( + EmbedDType, + EncodingFormat, + Endianness, + encode_pooling_bytes, + encode_pooling_output, +) logger = init_logger(__name__) -def _get_data( - output: PoolingOutput, - encoding_format: Literal["float", "base64"], -) -> Union[list[float], str]: - if encoding_format == "float": - return output.data.tolist() - elif encoding_format == "base64": - # Force to use float32 for base64 encoding - # to match the OpenAI python client behavior - pt_float32 = output.data.to(dtype=torch.float32) - pooling_bytes = np.array(pt_float32, dtype="float32").tobytes() - return base64.b64encode(pooling_bytes).decode("utf-8") - - assert_never(encoding_format) - - class OpenAIServingPooling(OpenAIServing): def __init__( self, engine_client: EngineClient, - vllm_config: VllmConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], - chat_template: Optional[str], + supported_tasks: tuple[SupportedTask, ...], + request_logger: RequestLogger | None, + chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, trust_request_chat_template: bool = False, log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, - model_config=vllm_config.model_config, models=models, request_logger=request_logger, log_error_stack=log_error_stack, ) + self.supported_tasks = supported_tasks self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format self.trust_request_chat_template = trust_request_chat_template - io_processor_plugin = self.model_config.io_processor_plugin - self.io_processor = get_io_processor(vllm_config, io_processor_plugin) async def create_pooling( self, request: PoolingRequest, - raw_request: Optional[Request] = None, - ) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]: + raw_request: Request | None = None, + ) -> PoolingResponse | IOProcessorResponse | PoolingBytesResponse | ErrorResponse: """ See https://platform.openai.com/docs/api-reference/embeddings/create for the 
API specification. This API mimics the OpenAI Embedding API. @@ -174,10 +161,28 @@ async def create_pooling( # Schedule the request and get the result generator. generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] try: - pooling_params = request.to_pooling_params() + if is_io_processor_request: + assert self.io_processor is not None and isinstance( + request, IOProcessorRequest + ) + pooling_params = self.io_processor.validate_or_generate_params() + else: + pooling_params = request.to_pooling_params() + + pooling_task: PoolingTask + if "token_embed" in self.supported_tasks: + pooling_task = "token_embed" + elif "token_classify" in self.supported_tasks: + pooling_task = "token_classify" + elif "plugin" in self.supported_tasks: + pooling_task = "plugin" + else: + return self.create_error_response( + f"pooling_task must be one of {self.supported_tasks}." + ) try: - pooling_params.verify("encode", self.model_config) + pooling_params.verify(pooling_task, self.model_config) except ValueError as e: return self.create_error_response(str(e)) @@ -225,7 +230,7 @@ async def create_pooling( num_prompts = len(engine_prompts) # Non-streaming response - final_res_batch: list[Optional[PoolingRequestOutput]] + final_res_batch: list[PoolingRequestOutput | None] final_res_batch = [None] * num_prompts try: async for i, res in result_generator: @@ -241,6 +246,8 @@ async def create_pooling( created_time, model_name, request.encoding_format, + request.embed_dtype, + request.endianness, ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -256,33 +263,67 @@ def request_output_to_pooling_response( request_id: str, created_time: int, model_name: str, - encoding_format: Literal["float", "base64"], - ) -> PoolingResponse: - items: list[PoolingResponseData] = [] - num_prompt_tokens = 0 - - for idx, final_res in enumerate(final_res_batch): - item = PoolingResponseData( - index=idx, - data=_get_data(final_res.outputs, encoding_format), + encoding_format: EncodingFormat, + embed_dtype: EmbedDType, + endianness: Endianness, + ) -> PoolingResponse | PoolingBytesResponse: + def encode_float_base64(): + items: list[PoolingResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + item = PoolingResponseData( + index=idx, + data=encode_pooling_output( + final_res, + encoding_format=encoding_format, + embed_dtype=embed_dtype, + endianness=endianness, + ), + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, ) - prompt_token_ids = final_res.prompt_token_ids - items.append(item) - num_prompt_tokens += len(prompt_token_ids) + return PoolingResponse( + id=request_id, + created=created_time, + model=model_name, + data=items, + usage=usage, + ) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - total_tokens=num_prompt_tokens, - ) + def encode_bytes(): + body, items, usage = encode_pooling_bytes( + pooling_outputs=final_res_batch, + embed_dtype=embed_dtype, + endianness=endianness, + ) - return PoolingResponse( - id=request_id, - created=created_time, - model=model_name, - data=items, - usage=usage, - ) + metadata = { + "id": request_id, + "created": created_time, + "model": model_name, + "data": items, + "usage": usage, + } + return PoolingBytesResponse( + body=body, + metadata=json.dumps(metadata), + ) + + if encoding_format == "float" or encoding_format == "base64": + return 
encode_float_base64() + elif encoding_format == "bytes": + return encode_bytes() + else: + assert_never(encoding_format) def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig: return RenderConfig( diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 998c279eea04..1fdb6997bc0a 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -6,11 +6,11 @@ import time import uuid from collections import deque -from collections.abc import AsyncGenerator, AsyncIterator, Sequence +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence from contextlib import AsyncExitStack from copy import copy from http import HTTPStatus -from typing import Callable, Final, Optional, Union +from typing import Final import jinja2 from fastapi import Request @@ -23,6 +23,8 @@ ResponseCodeInterpreterToolCallParam, ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, + ResponseFunctionCallArgumentsDeltaEvent, + ResponseFunctionCallArgumentsDoneEvent, ResponseFunctionToolCall, ResponseFunctionWebSearch, ResponseOutputItem, @@ -49,7 +51,6 @@ from openai_harmony import Message as OpenAIHarmonyMessage from vllm import envs -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, @@ -97,8 +98,7 @@ from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import SampleLogprobs from vllm.outputs import CompletionOutput -from vllm.reasoning import ReasoningParser, ReasoningParserManager -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -109,17 +109,16 @@ class OpenAIServingResponses(OpenAIServing): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], - chat_template: Optional[str], + request_logger: RequestLogger | None, + chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, return_tokens_as_token_ids: bool = False, reasoning_parser: str = "", enable_auto_tools: bool = False, - tool_parser: Optional[str] = None, - tool_server: Optional[ToolServer] = None, + tool_parser: str | None = None, + tool_server: ToolServer | None = None, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, @@ -127,11 +126,9 @@ def __init__( ) -> None: super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage, log_error_stack=log_error_stack, ) @@ -139,18 +136,9 @@ def __init__( self.chat_template_content_format: Final = chat_template_content_format self.enable_log_outputs = enable_log_outputs - self.reasoning_parser: Optional[Callable[[AnyTokenizer], ReasoningParser]] = ( - None + self.reasoning_parser = self._get_reasoning_parser( + reasoning_parser_name=reasoning_parser ) - if reasoning_parser: - try: - self.reasoning_parser = ReasoningParserManager.get_reasoning_parser( - reasoning_parser - ) - assert self.reasoning_parser is not None - except Exception as e: - raise TypeError(f"{reasoning_parser=} has not been registered") from e - 
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_force_include_usage = enable_force_include_usage self.default_sampling_params = self.model_config.get_diff_sampling_param() @@ -176,7 +164,7 @@ def __init__( "the store." ) - self.use_harmony = model_config.hf_config.model_type == "gpt_oss" + self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: logger.warning( "For gpt-oss, we ignore --enable-auto-tool-choice " @@ -223,7 +211,7 @@ def __init__( def _validate_generator_input( self, engine_prompt: EngineTokensPrompt - ) -> Optional[ErrorResponse]: + ) -> ErrorResponse | None: """Add validations to the input to the generator here.""" if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): error_message = ( @@ -239,19 +227,45 @@ def _validate_generator_input( ) return None + def _validate_create_responses_input( + self, request: ResponsesRequest + ) -> ErrorResponse | None: + if self.use_harmony and request.is_include_output_logprobs(): + return self.create_error_response( + err_type="invalid_request_error", + message="logprobs are not supported with gpt-oss models", + status_code=HTTPStatus.BAD_REQUEST, + ) + if request.store and not self.enable_store and request.background: + return self.create_error_response( + err_type="invalid_request_error", + message=( + "This vLLM engine does not support `store=True` and " + "therefore does not support the background mode. To " + "enable these features, set the environment variable " + "`VLLM_ENABLE_RESPONSES_API_STORE=1` when launching " + "the vLLM server." + ), + status_code=HTTPStatus.BAD_REQUEST, + ) + return None + async def create_responses( self, request: ResponsesRequest, - raw_request: Optional[Request] = None, - ) -> Union[ - AsyncGenerator[StreamingResponsesResponse, None], - ResponsesResponse, - ErrorResponse, - ]: + raw_request: Request | None = None, + ) -> ( + AsyncGenerator[StreamingResponsesResponse, None] + | ResponsesResponse + | ErrorResponse + ): error_check_ret = await self._check_model(request) if error_check_ret is not None: logger.error("Error with model %s", error_check_ret) return error_check_ret + maybe_validation_error = self._validate_create_responses_input(request) + if maybe_validation_error is not None: + return maybe_validation_error # If the engine is dead, raise the engine's DEAD_ERROR. # This is required for the streaming case, where we return a @@ -260,18 +274,6 @@ async def create_responses( raise self.engine_client.dead_error if request.store and not self.enable_store: - if request.background: - return self.create_error_response( - err_type="invalid_request_error", - message=( - "This vLLM engine does not support `store=True` and " - "therefore does not support the background mode. To " - "enable these features, set the environment variable " - "`VLLM_ENABLE_RESPONSES_API_STORE=1` when launching " - "the vLLM server." - ), - status_code=HTTPStatus.BAD_REQUEST, - ) # Disable the store option. # NOTE(woosuk): Although returning an error is possible, we opted # to implicitly disable store and process the request anyway, as @@ -279,12 +281,6 @@ async def create_responses( # (i.e., their request's `store=True` just because it's the default # value). request.store = False - if self.use_harmony and request.is_include_output_logprobs(): - return self.create_error_response( - err_type="invalid_request_error", - message="logprobs are not supported with gpt-oss models", - status_code=HTTPStatus.BAD_REQUEST, - ) # Handle the previous response ID. 
prev_response_id = request.previous_response_id @@ -369,6 +365,19 @@ async def create_responses( context = HarmonyContext(messages, available_tools) else: context = SimpleContext() + + if self.reasoning_parser is not None: + reasoning_parser = self.reasoning_parser(tokenizer) + if sampling_params.structured_outputs is None: + sampling_params.structured_outputs = StructuredOutputsParams() + struct_out = sampling_params.structured_outputs + if struct_out.all_non_structural_tag_constraints_none(): + sampling_params.structured_outputs.structural_tag = ( + reasoning_parser.prepare_structured_tag( + sampling_params.structured_outputs.structural_tag, + self.tool_server, + ) + ) generator = self._generate_with_builtin_tools( request_id=request.request_id, request_prompt=request_prompts[i], @@ -473,7 +482,7 @@ async def create_responses( async def _make_request( self, request: ResponsesRequest, - prev_response: Optional[ResponsesResponse], + prev_response: ResponsesResponse | None, tokenizer: AnyTokenizer, ): if len(request.tools) > 0: @@ -494,7 +503,7 @@ async def _make_request( def _make_request_with_harmony( self, request: ResponsesRequest, - prev_response: Optional[ResponsesResponse], + prev_response: ResponsesResponse | None, ): if request.tool_choice != "auto": raise NotImplementedError( @@ -535,8 +544,8 @@ async def responses_full_generator( model_name: str, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - created_time: Optional[int] = None, - ) -> Union[ErrorResponse, ResponsesResponse]: + created_time: int | None = None, + ) -> ErrorResponse | ResponsesResponse: if created_time is None: created_time = int(time.time()) @@ -601,10 +610,24 @@ async def responses_full_generator( input_tokens=num_prompt_tokens, output_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, - input_tokens_details=InputTokensDetails(cached_tokens=num_cached_tokens), + input_tokens_details=InputTokensDetails( + cached_tokens=num_cached_tokens, + input_tokens_per_turn=[ + turn.input_tokens for turn in context.all_turn_metrics + ], + cached_tokens_per_turn=[ + turn.cached_input_tokens for turn in context.all_turn_metrics + ], + ), output_tokens_details=OutputTokensDetails( reasoning_tokens=num_reasoning_tokens, tool_output_tokens=num_tool_output_tokens, + output_tokens_per_turn=[ + turn.output_tokens for turn in context.all_turn_metrics + ], + tool_output_tokens_per_turn=[ + turn.tool_output_tokens for turn in context.all_turn_metrics + ], ), ) response = ResponsesResponse.from_request( @@ -655,9 +678,9 @@ def _topk_logprobs( def _create_response_logprobs( self, token_ids: Sequence[int], - logprobs: Optional[SampleLogprobs], + logprobs: SampleLogprobs | None, tokenizer: AnyTokenizer, - top_logprobs: Optional[int] = None, + top_logprobs: int | None = None, ) -> list[Logprob]: assert logprobs is not None, "logprobs must be provided" assert len(token_ids) == len(logprobs), ( @@ -677,11 +700,13 @@ def _create_response_logprobs( token=text, logprob=max(token_logprob.logprob, -9999.0), bytes=list(text.encode("utf-8", errors="replace")), - top_logprobs=self._topk_logprobs( - logprob, top_logprobs=top_logprobs, tokenizer=tokenizer - ) - if top_logprobs - else [], + top_logprobs=( + self._topk_logprobs( + logprob, top_logprobs=top_logprobs, tokenizer=tokenizer + ) + if top_logprobs + else [] + ), ) ) return out @@ -689,9 +714,9 @@ def _create_response_logprobs( def _create_stream_response_logprobs( self, token_ids: Sequence[int], - logprobs: Optional[SampleLogprobs], + 
logprobs: SampleLogprobs | None, tokenizer: AnyTokenizer, - top_logprobs: Optional[int] = None, + top_logprobs: int | None = None, ) -> list[response_text_delta_event.Logprob]: lgs = self._create_response_logprobs( token_ids=token_ids, @@ -770,14 +795,16 @@ def _make_response_output_items( text=content, annotations=[], # TODO type="output_text", - logprobs=self._create_response_logprobs( - token_ids=final_output.token_ids, - logprobs=final_output.logprobs, - tokenizer=tokenizer, - top_logprobs=request.top_logprobs, - ) - if request.is_include_output_logprobs() - else None, + logprobs=( + self._create_response_logprobs( + token_ids=final_output.token_ids, + logprobs=final_output.logprobs, + tokenizer=tokenizer, + top_logprobs=request.top_logprobs, + ) + if request.is_include_output_logprobs() + else None + ), ) message = ResponseOutputMessage( id=f"msg_{random_uuid()}", @@ -806,7 +833,7 @@ def _make_response_output_items_with_harmony( def _construct_input_messages( self, request: ResponsesRequest, - prev_response: Optional[ResponsesResponse] = None, + prev_response: ResponsesResponse | None = None, ) -> list[ChatCompletionMessageParam]: messages: list[ChatCompletionMessageParam] = [] if request.instructions: @@ -843,17 +870,56 @@ def _construct_input_messages( messages.extend(request.input) # type: ignore return messages + def _construct_harmony_system_input_message( + self, request: ResponsesRequest, with_custom_tools: bool, tool_types: list[str] + ) -> OpenAIHarmonyMessage: + reasoning_effort = request.reasoning.effort if request.reasoning else None + enable_browser = ( + "web_search_preview" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("browser") + ) + enable_code_interpreter = ( + "code_interpreter" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("python") + ) + enable_container = ( + "container" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("container") + ) + sys_msg = get_system_message( + reasoning_effort=reasoning_effort, + browser_description=( + self.tool_server.get_tool_description("browser") + if enable_browser and self.tool_server is not None + else None + ), + python_description=( + self.tool_server.get_tool_description("python") + if enable_code_interpreter and self.tool_server is not None + else None + ), + container_description=( + self.tool_server.get_tool_description("container") + if enable_container and self.tool_server is not None + else None + ), + instructions=request.instructions, + with_custom_tools=with_custom_tools, + ) + return sys_msg + def _construct_input_messages_with_harmony( self, request: ResponsesRequest, - prev_response: Optional[ResponsesResponse], + prev_response: ResponsesResponse | None, ) -> list[OpenAIHarmonyMessage]: messages: list[OpenAIHarmonyMessage] = [] if prev_response is None: # New conversation. 
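An illustrative snippet with toy values of the gating rule applied by _construct_harmony_system_input_message above: a built-in tool is enabled only when the request lists its type and the tool server actually exposes the corresponding tool.

tool_types = ["web_search_preview", "code_interpreter"]  # derived from request.tools (toy)
available = {"browser"}                                  # tools the tool server exposes (toy)

enable_browser = "web_search_preview" in tool_types and "browser" in available
enable_code_interpreter = "code_interpreter" in tool_types and "python" in available
print(enable_browser, enable_code_interpreter)  # -> True False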
- reasoning_effort = request.reasoning.effort if request.reasoning else None tool_types = [tool.type for tool in request.tools] - # Allow the MCP Tool type to enable built in tools if the # server_label is allowlisted in # envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS @@ -864,35 +930,10 @@ def _construct_input_messages_with_harmony( and tool.server_label in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS ): tool_types.append(tool.server_label) - enable_browser = ( - "web_search_preview" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("browser") - ) - enable_code_interpreter = ( - "code_interpreter" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("python") - ) - enable_container = ( - "container" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("container") - ) with_custom_tools = has_custom_tools(tool_types) - sys_msg = get_system_message( - reasoning_effort=reasoning_effort, - browser_description=self.tool_server.get_tool_description("browser") - if enable_browser and self.tool_server is not None - else None, - python_description=self.tool_server.get_tool_description("python") - if enable_code_interpreter and self.tool_server is not None - else None, - container_description=self.tool_server.get_tool_description("container") - if enable_container and self.tool_server is not None - else None, - instructions=request.instructions, - with_custom_tools=with_custom_tools, + + sys_msg = self._construct_harmony_system_input_message( + request, with_custom_tools, tool_types ) messages.append(sys_msg) if with_custom_tools: @@ -941,6 +982,11 @@ def _construct_input_messages_with_harmony( # to add the tool call request to prev_outputs so that the # parse_response_input can find the tool call request when # parsing the tool call output. 
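For the streaming path further below, harmony addresses developer-defined tools on the "commentary" channel using a recipient of the form "functions.<name>", and the handler recovers the function name by stripping that prefix. An illustrative snippet with a hypothetical tool name:

recipient = "functions.get_weather"  # hypothetical recipient emitted by the parser
if recipient.startswith("functions."):
    function_name = recipient[len("functions."):]
    print(function_name)  # -> get_weather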
+ if ( + isinstance(response_msg, dict) + and response_msg.get("type") == "function_call" + ): + response_msg = ResponseFunctionToolCall.model_validate(response_msg) if isinstance(response_msg, ResponseFunctionToolCall): prev_outputs.append(response_msg) return messages @@ -999,7 +1045,7 @@ async def _run_background_request( async def responses_background_stream_generator( self, response_id: str, - starting_after: Optional[int] = None, + starting_after: int | None = None, ) -> AsyncGenerator[StreamingResponsesResponse, None]: if response_id not in self.event_store: raise ValueError(f"Unknown response_id: {response_id}") @@ -1024,13 +1070,13 @@ async def responses_background_stream_generator( async def retrieve_responses( self, response_id: str, - starting_after: Optional[int], - stream: Optional[bool], - ) -> Union[ - ErrorResponse, - ResponsesResponse, - AsyncGenerator[StreamingResponsesResponse, None], - ]: + starting_after: int | None, + stream: bool | None, + ) -> ( + ErrorResponse + | ResponsesResponse + | AsyncGenerator[StreamingResponsesResponse, None] + ): async with self.response_store_lock: response = self.response_store.get(response_id) @@ -1047,7 +1093,7 @@ async def retrieve_responses( async def cancel_responses( self, response_id: str, - ) -> Union[ErrorResponse, ResponsesResponse]: + ) -> ErrorResponse | ResponsesResponse: async with self.response_store_lock: response = self.response_store.get(response_id) if response is None: @@ -1095,7 +1141,7 @@ async def _process_simple_streaming_events( self, request: ResponsesRequest, sampling_params: SamplingParams, - result_generator: AsyncIterator[Optional[ConversationContext]], + result_generator: AsyncIterator[ConversationContext | None], context: ConversationContext, model_name: str, tokenizer: AnyTokenizer, @@ -1290,14 +1336,16 @@ async def _process_simple_streaming_events( output_index=current_output_index, item_id=current_item_id, delta=delta_message.content, - logprobs=self._create_stream_response_logprobs( - token_ids=output.token_ids, - logprobs=output.logprobs, - tokenizer=tokenizer, - top_logprobs=request.top_logprobs, - ) - if request.is_include_output_logprobs() - else [], + logprobs=( + self._create_stream_response_logprobs( + token_ids=output.token_ids, + logprobs=output.logprobs, + tokenizer=tokenizer, + top_logprobs=request.top_logprobs, + ) + if request.is_include_output_logprobs() + else [] + ), ) ) current_content_index += 1 @@ -1398,7 +1446,7 @@ async def _process_harmony_streaming_events( self, request: ResponsesRequest, sampling_params: SamplingParams, - result_generator: AsyncIterator[Optional[ConversationContext]], + result_generator: AsyncIterator[ConversationContext | None], context: ConversationContext, model_name: str, tokenizer: AnyTokenizer, @@ -1412,19 +1460,48 @@ async def _process_harmony_streaming_events( current_output_index = 0 current_item_id: str = "" sent_output_item_added = False - + is_first_function_call_delta = False async for ctx in result_generator: assert isinstance(ctx, StreamingHarmonyContext) if ctx.is_expecting_start(): current_output_index += 1 sent_output_item_added = False - + is_first_function_call_delta = False if len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] if previous_item.recipient is not None: - # Deal with tool call here - pass + # Deal with tool call + if previous_item.recipient.startswith("functions."): + function_name = previous_item.recipient[len("functions.") :] + yield _increment_sequence_number_and_return( + 
ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + arguments=previous_item.content[0].text, + name=function_name, + item_id=current_item_id, + output_index=current_output_index, + sequence_number=-1, + ) + ) + function_call_item = ResponseFunctionToolCall( + type="function_call", + arguments=previous_item.content[0].text, + name=function_name, + item_id=current_item_id, + output_index=current_output_index, + sequence_number=-1, + call_id=f"fc_{random_uuid()}", + status="completed", + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=function_call_item, + ) + ) elif previous_item.channel == "analysis": content = ResponseReasoningTextContent( text=previous_item.content[0].text, @@ -1780,17 +1857,54 @@ async def _process_harmony_streaming_events( ), ) ) + # developer tools will be triggered on the commentary channel + # and recipient starts with "functions.TOOL_NAME" + if ( + ctx.parser.current_channel == "commentary" + and ctx.parser.current_recipient + and ctx.parser.current_recipient.startswith("functions.") + ): + if is_first_function_call_delta is False: + is_first_function_call_delta = True + fc_name = ctx.parser.current_recipient[len("functions.") :] + tool_call_item = ResponseFunctionToolCall( + name=fc_name, + type="function_call", + id=current_item_id, + call_id=f"call_{random_uuid()}", + arguments="", + status="in_progress", + ) + current_item_id = f"fc_{random_uuid()}" + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=tool_call_item, + ) + ) + else: + yield _increment_sequence_number_and_return( + ResponseFunctionCallArgumentsDeltaEvent( + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + output_index=current_output_index, + sequence_number=-1, + type="response.function_call_arguments.delta", + ) + ) async def responses_stream_generator( self, request: ResponsesRequest, sampling_params: SamplingParams, - result_generator: AsyncIterator[Optional[ConversationContext]], + result_generator: AsyncIterator[ConversationContext | None], context: ConversationContext, model_name: str, tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, - created_time: Optional[int] = None, + created_time: int | None = None, ) -> AsyncGenerator[StreamingResponsesResponse, None]: # TODO: # 1. 
Handle disconnect @@ -1818,6 +1932,7 @@ def _increment_sequence_number_and_return( processer = self._process_harmony_streaming_events else: processer = self._process_simple_streaming_events + # TODO(Hanchen): make sampling params include the structural tag initial_response = ResponsesResponse.from_request( request, diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 234a31421828..9cbfc9791819 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -3,11 +3,10 @@ import asyncio import time from collections.abc import AsyncGenerator, Mapping -from typing import Any, Optional, Union +from typing import Any from fastapi import Request -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( @@ -38,7 +37,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import make_async, merge_async_iterators +from vllm.utils.async_utils import make_async, merge_async_iterators logger = init_logger(__name__) @@ -47,15 +46,13 @@ class ServingScores(OpenAIServing): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], + request_logger: RequestLogger | None, log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, log_error_stack=log_error_stack, @@ -66,12 +63,12 @@ async def _embedding_score( tokenizer: AnyTokenizer, texts_1: list[str], texts_2: list[str], - request: Union[RerankRequest, ScoreRequest], + request: RerankRequest | ScoreRequest, request_id: str, - tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[Union[LoRARequest, None]] = None, - trace_headers: Optional[Mapping[str, str]] = None, - ) -> Union[list[PoolingRequestOutput], ErrorResponse]: + tokenization_kwargs: dict[str, Any] | None = None, + lora_request: LoRARequest | None = None, + trace_headers: Mapping[str, str] | None = None, + ) -> list[PoolingRequestOutput] | ErrorResponse: input_texts = texts_1 + texts_2 engine_prompts: list[TokensPrompt] = [] @@ -128,7 +125,7 @@ async def _embedding_score( # Non-streaming response final_res_batch: list[PoolingRequestOutput] = [] - embeddings: list[Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) + embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts) async for i, res in result_generator: embeddings[i] = res @@ -155,11 +152,11 @@ async def _embedding_score( def _preprocess_score( self, - request: Union[RerankRequest, ScoreRequest], + request: RerankRequest | ScoreRequest, tokenizer: AnyTokenizer, tokenization_kwargs: dict[str, Any], - data_1: Union[str, ScoreContentPartParam], - data_2: Union[str, ScoreContentPartParam], + data_1: str | ScoreContentPartParam, + data_2: str | ScoreContentPartParam, ) -> tuple[str, TokensPrompt]: model_config = self.model_config @@ -179,14 +176,14 @@ def _preprocess_score( async def _cross_encoding_score( self, tokenizer: AnyTokenizer, - data_1: Union[list[str], list[ScoreContentPartParam]], - data_2: Union[list[str], list[ScoreContentPartParam]], - request: Union[RerankRequest, ScoreRequest], + data_1: list[str] |
list[ScoreContentPartParam], + data_2: list[str] | list[ScoreContentPartParam], + request: RerankRequest | ScoreRequest, request_id: str, - tokenization_kwargs: Optional[dict[str, Any]] = None, - lora_request: Optional[Union[LoRARequest, None]] = None, - trace_headers: Optional[Mapping[str, str]] = None, - ) -> Union[list[PoolingRequestOutput], ErrorResponse]: + tokenization_kwargs: dict[str, Any] | None = None, + lora_request: LoRARequest | None = None, + trace_headers: Mapping[str, str] | None = None, + ) -> list[PoolingRequestOutput] | ErrorResponse: request_prompts: list[str] = [] engine_prompts: list[TokensPrompt] = [] @@ -262,7 +259,7 @@ async def _cross_encoding_score( result_generator = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: list[Optional[PoolingRequestOutput]] = [None] * len( + final_res_batch: list[PoolingRequestOutput | None] = [None] * len( engine_prompts ) @@ -273,12 +270,12 @@ async def _cross_encoding_score( async def _run_scoring( self, - data_1: Union[list[str], str, ScoreMultiModalParam], - data_2: Union[list[str], str, ScoreMultiModalParam], - request: Union[ScoreRequest, RerankRequest], + data_1: list[str] | str | ScoreMultiModalParam, + data_2: list[str] | str | ScoreMultiModalParam, + request: ScoreRequest | RerankRequest, request_id: str, - raw_request: Optional[Request] = None, - ) -> Union[list[PoolingRequestOutput], ErrorResponse]: + raw_request: Request | None = None, + ) -> list[PoolingRequestOutput] | ErrorResponse: lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer() @@ -342,8 +339,8 @@ async def _run_scoring( async def create_score( self, request: ScoreRequest, - raw_request: Optional[Request] = None, - ) -> Union[ScoreResponse, ErrorResponse]: + raw_request: Request | None = None, + ) -> ScoreResponse | ErrorResponse: """ Score API similar to Sentence Transformers cross encoder @@ -380,8 +377,8 @@ async def create_score( return self.create_error_response(str(e)) async def do_rerank( - self, request: RerankRequest, raw_request: Optional[Request] = None - ) -> Union[RerankResponse, ErrorResponse]: + self, request: RerankRequest, raw_request: Request | None = None + ) -> RerankResponse | ErrorResponse: """ Rerank API based on JinaAI's rerank API; implements the same API interface.
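A minimal sketch of a client call against this route, assuming a vLLM server on localhost:8000 and a Jina-style /rerank path; the path, host, and model name here are illustrative, not taken from this diff:

import requests

# Jina-style rerank request; "documents" and "top_n" mirror the fields used
# by request_output_to_rerank_response further down in this file.
payload = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "Paris is the capital of France.",
        "The Nile is the longest river in Africa.",
    ],
    "top_n": 1,
}
resp = requests.post("http://localhost:8000/rerank", json=payload)
print(resp.json())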
Designed for compatibility with off-the-shelf @@ -471,7 +468,7 @@ def request_output_to_rerank_response( final_res_batch: list[PoolingRequestOutput], request_id: str, model_name: str, - documents: Union[list[str], ScoreMultiModalParam], + documents: list[str] | ScoreMultiModalParam, top_n: int, ) -> RerankResponse: """ diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 7b192dcd6c86..39aae0cd0495 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Final, Optional, Union +from typing import Any, Final import jinja2 from fastapi import Request -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger @@ -32,18 +31,16 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], - chat_template: Optional[str], + request_logger: RequestLogger | None, + chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, trust_request_chat_template: bool = False, log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, log_error_stack=log_error_stack, @@ -57,7 +54,7 @@ async def create_tokenize( self, request: TokenizeRequest, raw_request: Request, - ) -> Union[TokenizeResponse, ErrorResponse]: + ) -> TokenizeResponse | ErrorResponse: error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -132,7 +129,7 @@ async def create_detokenize( self, request: DetokenizeRequest, raw_request: Request, - ) -> Union[DetokenizeResponse, ErrorResponse]: + ) -> DetokenizeResponse | ErrorResponse: error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -158,7 +155,7 @@ async def create_detokenize( async def get_tokenizer_info( self, - ) -> Union[TokenizerInfoResponse, ErrorResponse]: + ) -> TokenizerInfoResponse | ErrorResponse: """Get comprehensive tokenizer information.""" try: tokenizer = await self.engine_client.get_tokenizer() @@ -174,7 +171,7 @@ def _build_render_config(self, request: TokenizeRequest) -> RenderConfig: @dataclass class TokenizerInfo: tokenizer: AnyTokenizer - chat_template: Optional[str] + chat_template: str | None def to_dict(self) -> dict[str, Any]: """Return the tokenizer configuration.""" diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 6cc31c1e08d3..33da7034afab 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -1,11 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import AsyncGenerator -from typing import Optional, Union from fastapi import Request -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( @@ -34,26 +32,26 
@@ class OpenAIServingTranscription(OpenAISpeechToText): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], + request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, log_error_stack: bool = False, + enable_force_include_usage: bool = False, ): super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, task_type="transcribe", log_error_stack=log_error_stack, + enable_force_include_usage=enable_force_include_usage, ) async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request - ) -> Union[TranscriptionResponse, AsyncGenerator[str, None], ErrorResponse]: + ) -> TranscriptionResponse | AsyncGenerator[str, None] | ErrorResponse: """Transcription API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/audio/createTranscription @@ -95,26 +93,26 @@ class OpenAIServingTranslation(OpenAISpeechToText): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], + request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, log_error_stack: bool = False, + enable_force_include_usage: bool = False, ): super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, task_type="translate", log_error_stack=log_error_stack, + enable_force_include_usage=enable_force_include_usage, ) async def create_translation( self, audio_data: bytes, request: TranslationRequest, raw_request: Request - ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]: + ) -> TranslationResponse | AsyncGenerator[str, None] | ErrorResponse: """Translation API similar to OpenAI's API. 
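The enable_force_include_usage flag threaded through these constructors lets the server force usage statistics into streaming speech-to-text responses even when a client does not ask for them. A minimal sketch of a plain, non-streaming call to this endpoint, assuming a server on localhost:8000 and a Whisper-style model; the file name and model are placeholders:

import requests

# OpenAI-compatible audio translation request sent as a multipart form upload.
with open("speech_in_french.wav", "rb") as audio_file:
    resp = requests.post(
        "http://localhost:8000/v1/audio/translations",
        files={"file": audio_file},
        data={"model": "openai/whisper-large-v3"},
    )
print(resp.json())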
See https://platform.openai.com/docs/api-reference/audio/createTranslation diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 779498b308e8..46139642c50c 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -4,15 +4,14 @@ import io import math import time -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from functools import cached_property -from typing import Callable, Literal, Optional, TypeVar, Union, cast +from typing import Literal, TypeAlias, TypeVar, cast import numpy as np from fastapi import Request import vllm.envs as envs -from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( @@ -33,14 +32,14 @@ from vllm.logger import init_logger from vllm.model_executor.models import SupportsTranscription from vllm.outputs import RequestOutput -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule try: import librosa except ImportError: librosa = PlaceholderModule("librosa") # type: ignore[assignment] -SpeechToTextResponse = Union[TranscriptionResponse, TranslationResponse] +SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse T = TypeVar("T", bound=SpeechToTextResponse) logger = init_logger(__name__) @@ -53,17 +52,16 @@ class OpenAISpeechToText(OpenAIServing): def __init__( self, engine_client: EngineClient, - model_config: ModelConfig, models: OpenAIServingModels, *, - request_logger: Optional[RequestLogger], + request_logger: RequestLogger | None, return_tokens_as_token_ids: bool = False, task_type: Literal["transcribe", "translate"] = "transcribe", log_error_stack: bool = False, + enable_force_include_usage: bool = False, ): super().__init__( engine_client=engine_client, - model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, @@ -74,9 +72,11 @@ def __init__( self.task_type = task_type self.asr_config = self.model_cls.get_speech_to_text_config( - model_config, task_type + self.model_config, task_type ) + self.enable_force_include_usage = enable_force_include_usage + self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB if self.default_sampling_params: @@ -143,7 +143,7 @@ async def _create_speech_to_text( raw_request: Request, response_class: type[T], stream_generator_method: Callable[..., AsyncGenerator[str, None]], - ) -> Union[T, AsyncGenerator[str, None], ErrorResponse]: + ) -> T | AsyncGenerator[str, None] | ErrorResponse: """Base method for speech-to-text operations like transcription and translation.""" error_check_ret = await self._check_model(request) @@ -184,9 +184,7 @@ async def _create_speech_to_text( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - list_result_generator: Optional[list[AsyncGenerator[RequestOutput, None]]] = ( - None - ) + list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None try: # Unlike most decoder-only models, whisper generation length is not # constrained by the size of the input audio, which is mapped to a @@ -255,13 +253,10 @@ async def _speech_to_text_stream_generator( request_metadata: RequestResponseMetadata, audio_duration_s: float, chunk_object_type: Literal["translation.chunk", "transcription.chunk"], - response_stream_choice_class: Union[ - 
type[TranscriptionResponseStreamChoice], - type[TranslationResponseStreamChoice], - ], - stream_response_class: Union[ - type[TranscriptionStreamResponse], type[TranslationStreamResponse] - ], + response_stream_choice_class: type[TranscriptionResponseStreamChoice] + | type[TranslationResponseStreamChoice], + stream_response_class: type[TranscriptionStreamResponse] + | type[TranslationStreamResponse], ) -> AsyncGenerator[str, None]: created_time = int(time.time()) model_name = request.model @@ -269,9 +264,7 @@ async def _speech_to_text_stream_generator( completion_tokens = 0 num_prompt_tokens = 0 - include_usage = ( - request.stream_include_usage if request.stream_include_usage else False - ) + include_usage = self.enable_force_include_usage or request.stream_include_usage include_continuous_usage = ( request.stream_continuous_usage_stats if include_usage and request.stream_continuous_usage_stats diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 2c5a0a6af23f..a72772f59cf2 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -4,6 +4,7 @@ from .abstract_tool_parser import ToolParser, ToolParserManager from .deepseekv3_tool_parser import DeepSeekV3ToolParser from .deepseekv31_tool_parser import DeepSeekV31ToolParser +from .ernie45_tool_parser import Ernie45ToolParser from .glm4_moe_tool_parser import Glm4MoeModelToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_tool_parser import GraniteToolParser @@ -17,6 +18,7 @@ from .longcat_tool_parser import LongcatFlashToolParser from .minimax_tool_parser import MinimaxToolParser from .mistral_tool_parser import MistralToolParser +from .olmo3_tool_parser import Olmo3PythonicToolParser from .openai_tool_parser import OpenAIToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser @@ -42,7 +44,9 @@ "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser", "DeepSeekV31ToolParser", + "Ernie45ToolParser", "xLAMToolParser", + "Olmo3PythonicToolParser", "MinimaxToolParser", "KimiK2ToolParser", "HunyuanA13BToolParser", diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index e6ee2fa777f8..212326fdafb1 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -2,18 +2,22 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import cached_property -from typing import Callable, Optional, Union from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation, ) +from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools from vllm.logger import init_logger +from vllm.sampling_params import ( + StructuredOutputsParams, +) from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import import_from_path, is_list_of +from vllm.utils.collection_utils import is_list_of +from vllm.utils.import_utils import import_from_path logger = init_logger(__name__) @@ -44,6 +48,18 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques """ Static method that used to adjust the request parameters. 
""" + if not request.tools: + return request + json_schema_from_tool = get_json_schema_from_tools( + tool_choice=request.tool_choice, tools=request.tools + ) + # Set structured output params for tool calling + if json_schema_from_tool is not None: + if request.structured_outputs is None: + request.structured_outputs = StructuredOutputsParams() + # tool_choice: "Forced Function" or "required" will override + # structured output json settings to make tool calling work correctly + request.structured_outputs.json = json_schema_from_tool return request def extract_tool_calls( @@ -69,7 +85,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Instance method that should be implemented for extracting tool calls from an incomplete response; for use when handling tool calls and @@ -101,7 +117,7 @@ def get_tool_parser(cls, name) -> type: def _register_module( cls, module: type, - module_name: Optional[Union[str, list[str]]] = None, + module_name: str | list[str] | None = None, force: bool = True, ) -> None: if not issubclass(module, ToolParser): @@ -123,10 +139,10 @@ def _register_module( @classmethod def register_module( cls, - name: Optional[Union[str, list[str]]] = None, + name: str | list[str] | None = None, force: bool = True, - module: Union[type, None] = None, - ) -> Union[type, Callable]: + module: type | None = None, + ) -> type | Callable: """ Register module with the given name or name list. it can be used as a decoder(with module as None) or normal function(with module as not diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py index c6e8f1686e24..14fd5cf0941c 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Union import regex as re @@ -129,7 +128,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a @@ -272,7 +271,7 @@ def extract_tool_calls_streaming( if not self.current_tool_name_sent: if current_tool_call is None: return None - function_name: Union[str, None] = current_tool_call.get("name") + function_name: str | None = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True return DeltaMessage( diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index e8a5d2e6dc13..b256560fb4be 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Union import regex as re @@ -129,7 +128,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: 
logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a @@ -272,7 +271,7 @@ def extract_tool_calls_streaming( if not self.current_tool_name_sent: if current_tool_call is None: return None - function_name: Union[str, None] = current_tool_call.get("name") + function_name: str | None = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True return DeltaMessage( diff --git a/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py new file mode 100644 index 000000000000..e4696334eb13 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence + +import regex as re + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("ernie45") +class Ernie45ToolParser(ToolParser): + def __init__(self, tokenizer: AnyTokenizer): + """ + Ernie thinking model format: + abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n + """ + super().__init__(tokenizer) + self.current_tool_name_sent = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id = -1 + self.streamed_args_for_tool: list[str] = [] + self.think_end_token = "</think>" + self.response_start_token: str = "<response>" + self.response_end_token: str = "</response>" + self.tool_call_start_token = "<tool_call>" + self.tool_call_end_token = "</tool_call>" + self.tool_calls_start_token = self.tool_call_start_token + self.newline_token: str = "<0x0A>" + + self.tool_call_regex = re.compile( + r"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>", re.DOTALL + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction." 
+ ) + + self.think_end_token_id = self.vocab.get(self.think_end_token) + self.response_start_token_id = self.vocab.get(self.response_start_token) + self.response_end_token_id = self.vocab.get(self.response_end_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + self.newline_token_id = self.vocab.get(self.newline_token) + self.parser_token_ids = [ + self.think_end_token_id, + self.response_start_token_id, + self.response_end_token_id, + ] + + self._buffer = "" + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + else: + try: + tool_call_json_list = self.tool_call_regex.findall(model_output) + + tool_calls = [] + for tool_call_json in tool_call_json_list: + tool_call_dict = json.loads(tool_call_json) + args_str = json.dumps( + tool_call_dict.get("arguments", {}), ensure_ascii=False + ) + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=tool_call_dict.get("name", ""), + arguments=args_str, + ), + ) + ) + + content = model_output[ + : model_output.find(self.tool_calls_start_token) + ].rstrip("\n") + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + self._buffer += delta_text + cur_text = self._buffer + start_idx = cur_text.find(self.tool_call_start_token) + if start_idx == -1: + self._buffer = "" + # At least one toolcall has been completed + if self.current_tool_id > 0: + cur_text = "" + if self.current_tool_id == -1 and all( + token_id == self.newline_token_id for token_id in previous_token_ids + ): + cur_text = cur_text.strip("\n") + + # handle <response> </response> when tool_call is not triggered + # cur_text === delta_text + content = cur_text + if self.response_start_token_id in delta_token_ids: + content = content.lstrip("\n") + response_start_idx = content.find(self.response_start_token) + content = content[response_start_idx + len(self.response_start_token) :] + # if have </response>, remove it + response_end_idx = content.rfind(self.response_end_token) + if response_end_idx != -1: + content = content[:response_end_idx] + elif self.response_end_token_id in delta_token_ids: + response_end_idx = content.rfind(self.response_end_token) + content = content[:response_end_idx] + # remove \n after </think> or <response> or </response> + if ( + len(previous_token_ids) > 0 + and previous_token_ids[-1] in self.parser_token_ids + ) and ( + len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id + ): + content = content.lstrip("\n") + + return DeltaMessage(content=content if content else None) + logger.debug("cur_text = %s", cur_text) + end_idx = cur_text.find(self.tool_call_end_token) + if end_idx != -1: + if self.current_tool_id 
== -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + extracted_tool_calls = self.extract_tool_calls( + cur_text[: end_idx + len(self.tool_call_end_token)], request + ) + + if len(extracted_tool_calls.tool_calls) == 0: + logger.warning("Failed to extract any tool calls.") + return None + tool_call = extracted_tool_calls.tool_calls[0] + self.prev_tool_call_arr[self.current_tool_id] = { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + self.streamed_args_for_tool[self.current_tool_id] = ( + tool_call.function.arguments + ) + delta = DeltaMessage( + content=extracted_tool_calls.content, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + id=tool_call.id, + type=tool_call.type, + function=DeltaFunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + ], + ) + self.current_tool_id += 1 + self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :] + return delta + + self._buffer = cur_text[start_idx:] + content = cur_text[:start_idx].rstrip("\n") + return DeltaMessage(content=content if content else None) diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py index 1d7d7d3f8629..5081b38240ce 100644 --- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -4,7 +4,7 @@ import ast import json from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import regex as re @@ -66,7 +66,7 @@ def extract_tool_calls( def _is_string_type( tool_name: str, arg_name: str, - tools: Optional[list[ChatCompletionToolsParam]], + tools: list[ChatCompletionToolsParam] | None, ) -> bool: if tools is None: return False @@ -144,7 +144,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: self._buffer += delta_text cur_text = self._buffer start_idx = cur_text.find(self.tool_call_start_token) diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index c42b358b1e34..c5246685f407 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -4,7 +4,6 @@ import json from collections.abc import Sequence from json import JSONDecoder -from typing import Union import partial_json_parser import regex as re @@ -121,7 +120,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: if len(current_text) < len(self.bot_token) and self.bot_token.startswith( current_text ): diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index 989973923ae5..cc1f50034235 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -3,7 +3,6 @@ import json from 
collections.abc import Sequence -from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -108,7 +107,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: start_idx = consume_space(0, current_text) if current_text[start_idx:].startswith(self.bot_token): start_idx = consume_space(start_idx + len(self.bot_token), current_text) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 4529eb51796e..6332de42f424 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -3,7 +3,6 @@ import json from collections.abc import Sequence -from typing import Union import partial_json_parser import regex as re @@ -113,6 +112,7 @@ def tool_call_delta_buffer(self, delta_text: str): return delta_text def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if request.tools and request.tool_choice != "none": # do not skip special tokens because the tool_call tokens are # marked "special" in some models. Since they are skipped @@ -181,7 +181,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: # 1. All tokens are parsed based on _text, not token_ids. # 2. All incoming text data is processed by the tool_call_delta_buffer # function for buffering before being used for parsing. @@ -333,7 +333,7 @@ def extract_tool_calls_streaming( if not self.current_tool_name_sent: if current_tool_call is None: return None - function_name: Union[str, None] = current_tool_call.get("name") + function_name: str | None = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True return DeltaMessage( diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py index 1855d69adb21..b32e6e39b3e5 100644 --- a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py @@ -4,7 +4,7 @@ import json from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import regex as re @@ -73,7 +73,7 @@ def __init__(self, tokenizer: AnyTokenizer): def preprocess_model_output( self, model_output: str - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: # find the location tool call for match in self.answer_tool_calls_pattern.finditer(model_output): start, end = match.span() @@ -176,7 +176,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Extract tool calls for streaming mode. 
""" diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 9adaea297b05..c87bab4353b5 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -3,7 +3,6 @@ import json from collections.abc import Sequence -from typing import Union import partial_json_parser from partial_json_parser.core.options import Allow @@ -36,6 +35,7 @@ def __init__(self, tokenizer: AnyTokenizer): self.position = 0 def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if request.tools and request.tool_choice != "none": # do not skip special tokens because internlm use the special # tokens to indicate the start and end of the tool calls @@ -59,7 +59,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: if "<|action_start|>" not in current_text: self.position = len(current_text) return DeltaMessage(content=delta_text) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 1ae3e0da3351..21ee2b762cd0 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -3,7 +3,6 @@ import json from collections.abc import Sequence -from typing import Union import partial_json_parser import regex as re @@ -69,6 +68,7 @@ def __init__(self, tokenizer: AnyTokenizer): ) def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if request.tools and request.tool_choice != "none": # do not skip special tokens because jamba use the special # tokens to indicate the start and end of the tool calls @@ -129,7 +129,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: # if the tool call token is not in the tokens generated so far, append # output to contents since it's not a tool if self.tool_calls_start_token not in current_text: @@ -190,7 +190,7 @@ def extract_tool_calls_streaming( # auto-generated due to JSON completions, but wasn't # streamed to the client yet. 
if self.current_tool_id >= 0: - diff: Union[str, None] = current_tool_call.get("arguments") + diff: str | None = current_tool_call.get("arguments") if diff: diff = json.dumps(diff, ensure_ascii=False).replace( diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py index a2eff21a4466..98a52ddd60d6 100644 --- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py @@ -3,7 +3,6 @@ # code modified from deepseekv3_tool_parser.py from collections.abc import Sequence -from typing import Union import regex as re @@ -131,7 +130,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a @@ -278,7 +277,7 @@ def extract_tool_calls_streaming( if not self.current_tool_name_sent: if current_tool_call is None: return None - function_name: Union[str, None] = current_tool_call.get("name") + function_name: str | None = current_tool_call.get("name") tool_id = current_tool_call.get("id") if function_name: self.current_tool_name_sent = True diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 162675efbc9a..dd622b69525d 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -3,7 +3,7 @@ import ast import json from collections.abc import Sequence -from typing import Any, Union +from typing import Any import regex as re from transformers import PreTrainedTokenizerBase @@ -128,7 +128,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: if not current_text.startswith("[") and not current_text.startswith( "<|python_start|>" ): @@ -245,7 +245,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: ) -def _make_valid_python(text: str) -> Union[tuple[str, str], None]: +def _make_valid_python(text: str) -> tuple[str, str] | None: bracket_stack = [] for index, char in enumerate(text): if char in {"[", "(", "{"}: @@ -317,7 +317,7 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: def _compute_tool_delta( previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str -) -> Union[DeltaToolCall, None]: +) -> DeltaToolCall | None: new_call_args = new_call.function.arguments if withheld_suffix: assert new_call_args.endswith(withheld_suffix) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 4d5ef5ed64aa..8c7b3cefb200 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -3,7 +3,6 @@ import json from collections.abc import Sequence -from typing import Union import partial_json_parser import regex as re @@ -134,7 +133,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: if not ( 
current_text.startswith(self.bot_token) or current_text.startswith("{") ): diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 0b83fd237a6a..4b12bf68b367 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -3,7 +3,7 @@ import json from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import regex as re @@ -509,7 +509,7 @@ def _extract_tool_args(self, tool_content: str, args_match: re.Match[str]) -> st def _get_current_tool_content( self, text: str, tool_index: int - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """ Get the content of a specific tool by index. @@ -545,7 +545,7 @@ def _get_current_tool_content( def _handle_tool_name_streaming( self, tool_content: str, tool_count: int - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Handle streaming of tool names. @@ -595,7 +595,7 @@ def _handle_tool_name_streaming( def _handle_tool_args_streaming( self, tool_content: str, tool_count: int - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Handle streaming of tool arguments. @@ -702,7 +702,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: self._update_thinking_state(current_text) if self.in_thinking_tag: @@ -776,7 +776,7 @@ def extract_tool_calls_streaming( ) return None - def _find_tool_start_outside_thinking(self, current_text: str) -> Optional[int]: + def _find_tool_start_outside_thinking(self, current_text: str) -> int | None: """ Find the start position of tool calls outside of thinking tags. @@ -809,7 +809,7 @@ def _find_tool_start_outside_thinking(self, current_text: str) -> Optional[int]: def _extract_content_before_tools( self, current_text: str, delta_text: str, tool_start: int - ) -> Optional[str]: + ) -> str | None: """ Extract content that appears before tool calls. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index b3b8960276bc..dbdf0085367b 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -5,7 +5,6 @@ from collections.abc import Sequence from random import choices from string import ascii_letters, digits -from typing import Union import partial_json_parser import regex as re @@ -95,6 +94,7 @@ def __init__(self, tokenizer: AnyTokenizer): ) def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if ( not isinstance(self.model_tokenizer, MistralTokenizer) and request.tools @@ -194,7 +194,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: # if the tool call token is not in the tokens generated so far, append # output to contents since it's not a tool if self.bot_token not in current_text: @@ -252,7 +252,7 @@ def extract_tool_calls_streaming( # auto-generated due to JSON completions, but wasn't # streamed to the client yet. 
if self.current_tool_id >= 0: - diff: Union[str, None] = current_tool_call.get("arguments") + diff: str | None = current_tool_call.get("arguments") if diff: diff = json.dumps(diff, ensure_ascii=False).replace( diff --git a/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py new file mode 100644 index 000000000000..ed5633aac02d --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import json +from collections.abc import Sequence +from typing import Any + +import regex as re +from transformers import PreTrainedTokenizerBase + +import vllm.envs as envs +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("olmo3") +class Olmo3PythonicToolParser(ToolParser): + """ + Tool call parser for Olmo 3 models that produce tool calls as + newline-separated pythonic strings. + Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set + Code copied from pythonic_tool_parser.py and updated to handle + - newline separated pythonic tool calls. + - argument values being null/true/false instead of Pythonic literals. + """ + + # TODO(mdepinet): Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL, + ) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + def extract_tool_calls( + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + original_model_output = model_output + # Remove xml tags. + match = re.search( + r"<function_calls>(.*?)</function_calls>", model_output, re.DOTALL + ) + if match: + model_output = match.group(1).strip() + # Make the newline separated function calls into a list. 
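# For example (tool names here are hypothetical), a cleaned-up model output of
#     get_weather(city="Paris")
#     get_time(timezone="CET")
# is joined into '[get_weather(city="Paris"), get_time(timezone="CET")]'
# so that ast.parse below can read it as a single list of pythonic calls.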
+ model_output = ", ".join( + [line.strip() for line in model_output.splitlines() if line.strip()] + ) + model_output = f"[{model_output}]" + + is_tool_call_pattern = False + try: + is_tool_call_pattern = ( + self.TOOL_CALL_REGEX.match( + model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS + ) + is not None + ) + except TimeoutError: + logger.warning("Regex timeout occurred when matching tool call pattern.") + logger.debug( + "Regex timeout occurred when matching user input: %s", model_output + ) + + if not is_tool_call_pattern: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=original_model_output + ) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts + ): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None, + ) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls" + ) + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=original_model_output + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + # All function calls start with the <function_calls> tag. + # But since this is streaming, we may have seen only part of the tag. + if not current_text.startswith("<"): + return DeltaMessage(content=delta_text) + + try: + # Remove xml tags. + if current_text.startswith("<function_calls>"): + current_text = current_text[len("<function_calls>") :] + if current_text.endswith("</function_calls>"): + current_text = current_text[: -len("</function_calls>")] + + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + # Make the newline separated function calls into a list. + valid_text = ", ".join( + [line.strip() for line in valid_text.splitlines() if line.strip()] + ) + valid_text = f"[{valid_text}]" + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts + ): + raise _UnexpectedAstError( + "Tool output must be a sequence of newline-separated calls" + ) + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = index < len(tool_calls) - 1 or ")" not in added_text + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = added_text[:-1] if not new_call_complete else "" + if not new_call_complete and added_text[-1] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. 
+ withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta( + self.streamed_args_for_tool[index], new_call, index, withheld_suffix + ) + + if delta is not None: + tool_deltas.append(delta) + if ( + delta.function is not None + and delta.function.arguments is not None + ): + self.streamed_args_for_tool[index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining its final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage(content="") + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction error" + ) + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError("Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + # The model may return function calls where the values are null/true/false + # because the system prompt has API description in json. + elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]: + if val.id == "null": + return None + elif val.id == "true": + return True + elif val.id == "false": + return False + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall( + type="function", + function=FunctionCall( + name=function_name, arguments=json.dumps(arguments, ensure_ascii=False) + ), + ) + + +def _make_valid_python(text: str) -> tuple[str, str] | None: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. 
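# e.g. while inside the single-quoted string 'say "hi"', the inner double
# quotes neither open nor close a string, so they are skipped here.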
+ pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. + return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[: text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[: text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name + if text.endswith(","): + text = text[:-1] + if ( + bracket_stack + and bracket_stack[-1] == "[" + and not text.endswith("[") + and not text.endswith(")") + ): + return None # Incomplete function name + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta( + previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str +) -> DeltaToolCall | None: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[: -len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall( + id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + ), + ) + + arg_diff = new_call_args[len(previously_sent_args) :] + return ( + DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff) + ) + if arg_diff + else None + ) diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py index 8d7cbbfba649..f44876943ac2 100644 --- a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import json from collections.abc import Sequence from typing import TYPE_CHECKING @@ -22,13 +20,15 @@ if TYPE_CHECKING: from vllm.transformers_utils.tokenizer import AnyTokenizer +else: + AnyTokenizer = object logger = init_logger(__name__) @ToolParserManager.register_module("openai") class OpenAIToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): + def __init__(self, tokenizer: "AnyTokenizer"): super().__init__(tokenizer) def extract_tool_calls( diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 114987e5600b..a8387ba1494d 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -3,7 +3,7 @@ import json from collections.abc import Sequence -from typing import Any, Optional +from typing import Any import regex as re from transformers import PreTrainedTokenizerBase @@ -118,5 +118,5 @@ def extract_tool_calls_streaming( current_token_ids: 
Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Optional[DeltaMessage]: + ) -> DeltaMessage | None: return None diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 272068a6f0ac..4945e7b5ab20 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -4,7 +4,7 @@ import ast import json from collections.abc import Sequence -from typing import Any, Union +from typing import Any import regex as re from transformers import PreTrainedTokenizerBase @@ -124,7 +124,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: if not current_text.startswith("["): return DeltaMessage(content=delta_text) @@ -236,7 +236,7 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: ) -def _make_valid_python(text: str) -> Union[tuple[str, str], None]: +def _make_valid_python(text: str) -> tuple[str, str] | None: bracket_stack = [] for index, char in enumerate(text): if char in {"[", "(", "{"}: @@ -308,7 +308,7 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: def _compute_tool_delta( previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str -) -> Union[DeltaToolCall, None]: +) -> DeltaToolCall | None: new_call_args = new_call.function.arguments if withheld_suffix: assert new_call_args.endswith(withheld_suffix) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index a41ca30bf527..ad56972e6387 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -4,7 +4,7 @@ import json import uuid from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import regex as re @@ -36,7 +36,7 @@ def __init__(self, tokenizer: AnyTokenizer): self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] # Override base class type - we use string IDs for tool calls - self.current_tool_id: Optional[str] = None # type: ignore + self.current_tool_id: str | None = None # type: ignore self.streamed_args_for_tool: list[str] = [] # Sentinel tokens for streaming mode @@ -110,7 +110,7 @@ def _reset_streaming_state(self): self.streaming_request = None def _get_arguments_config( - self, func_name: str, tools: Optional[list[ChatCompletionToolsParam]] + self, func_name: str, tools: list[ChatCompletionToolsParam] | None ) -> dict: """Extract argument configuration for a function.""" if tools is None: @@ -240,8 +240,8 @@ def _convert_param_value( return param_value def _parse_xml_function_call( - self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] - ) -> Optional[ToolCall]: + self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None + ) -> ToolCall | None: # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] @@ -349,7 +349,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: # Store request for type conversion if not previous_text: self._reset_streaming_state() diff --git 
a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py index 1b7e4fec316e..9964d1ac25c4 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py @@ -2,13 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast import json -import uuid from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any from xml.parsers.expat import ParserCreate import regex as re +from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionToolsParam, @@ -39,7 +39,7 @@ def __init__(self): self.reset_streaming_state() # Tool configuration information - self.tools: Union[list[ChatCompletionToolsParam], None] = None + self.tools: list[ChatCompletionToolsParam] | None = None self.tool_call_start_token: str = "<tool_call>" self.tool_call_end_token: str = "</tool_call>" self.function_start_token: str = "<function=" @@ -341,7 +341,7 @@ def _should_skip_element(self, element: str) -> bool: # Skip blank content return not element - def _find_next_complete_element(self, start_pos: int) -> tuple[Optional[str], int]: + def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]: """ Find next complete XML element from specified position @@ -375,14 +375,21 @@ def _find_next_complete_element(self, start_pos: int) -> tuple[Optional[str], in return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 else: # If currently not parsing tool calls (entering a tool_call), - # check if starts with <tool_call> + # check if starts with <tool_call> or <function= if self.current_call_id is None: # Check if might be start of <tool_call> if buffer == "<tool_call>"[: len(buffer)]: # Might be start of <tool_call>, wait for more data return None, start_pos + elif ( + buffer.startswith("<function=") + or buffer == "<function="[: len(buffer)] + ): + # Might be start of <function=, wait for more data + # to get the complete function tag + return None, start_pos else: - # Not start of <tool_call>, treat as text + # Not start of <tool_call> or <function=, treat as text return buffer, start_pos + len(buffer) else: # When parsing tool calls, @@ -584,7 +591,7 @@ def _emit_delta(self, delta: DeltaMessage): """Emit Delta response (streaming output)""" self.deltas.append(delta) - def _auto_close_open_parameter_if_needed(self, incoming_tag: Optional[str] = None): + def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None): """Before starting to process new elements, if there are unclosed tags from before, automatically complete their endings to the parser. 
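The streaming change above (waiting when the buffer may still grow into `<tool_call>` or `<function=`) comes down to a prefix test on the partially received text. A standalone sketch of that test, with an illustrative helper name that is not part of the patch:

def could_be_tag_prefix(buffer: str, tag: str) -> bool:
    # True if the buffer already starts with the tag, or is still a prefix of it
    return buffer.startswith(tag) or buffer == tag[: len(buffer)]

assert could_be_tag_prefix("<fun", "<function=")            # wait for more data
assert could_be_tag_prefix("<function=get_weather>", "<function=")
assert not could_be_tag_prefix("hello", "<tool_call>")      # emit as plain text

In the parser this is what decides between returning (None, start_pos) to wait for more streamed text and returning the buffer as ordinary content.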
@@ -621,7 +628,7 @@ def _start_element(self, name: str, attrs: dict[str, str]): self._auto_close_open_parameter_if_needed("tool_call") self.parameters = {} - self.current_call_id = self._get_next_call_id() + self.current_call_id = make_tool_call_id() self.current_param_is_first = True self.tool_call_index += 1 elif name.startswith("function") or (name == "function"): @@ -953,15 +960,11 @@ def setup_parser(self): self.parser.EndElementHandler = self._end_element self.parser.CharacterDataHandler = self._char_data - def set_tools(self, tools: Union[list[ChatCompletionToolsParam], None]): + def set_tools(self, tools: list[ChatCompletionToolsParam] | None): """Set tool configuration information""" self.tools = tools - def _get_next_call_id(self): - """Generate unique call ID""" - return f"call_{uuid.uuid4().hex[:24]}" - - def _extract_function_name(self, name: str, attrs: dict[str, str]) -> Optional[str]: + def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None: """Extract function name from various formats""" if attrs and "name" in attrs: return attrs["name"] @@ -973,9 +976,7 @@ def _extract_function_name(self, name: str, attrs: dict[str, str]) -> Optional[s return None - def _extract_parameter_name( - self, name: str, attrs: dict[str, str] - ) -> Optional[str]: + def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None: """Extract parameter name from various formats""" if attrs and "name" in attrs: return attrs["name"] @@ -1170,6 +1171,10 @@ def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.parser = StreamingXMLToolCallParser() + # Add missing attributes for compatibility with serving_chat.py + self.prev_tool_call_arr: list[dict] = [] + self.streamed_args_for_tool: list[str] = [] + logger.info( "vLLM Successfully import tool parser %s !", self.__class__.__name__ ) @@ -1180,6 +1185,9 @@ def extract_tool_calls( request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: self.parser.reset_streaming_state() + # Reset tool call tracking arrays for new extraction + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] if request: self.parser.set_tools(request.tools) result = self.parser.parse_single_streaming_chunks(model_output) @@ -1203,6 +1211,34 @@ def extract_tool_calls( ), ) ) + + # Update tool call tracking arrays for compatibility + tool_index = ( + tool_call.index + if tool_call.index is not None + else len(self.prev_tool_call_arr) - 1 + ) + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({"name": "", "arguments": ""}) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + # Update tool call information + self.prev_tool_call_arr[tool_index]["name"] = ( + tool_call.function.name + ) + self.prev_tool_call_arr[tool_index]["arguments"] = ( + tool_call.function.arguments + ) + + # Update streamed arguments + if tool_call.function.arguments: + self.streamed_args_for_tool[tool_index] = ( + tool_call.function.arguments + ) + return ExtractedToolCallInformation( tool_calls=tool_calls, tools_called=len(tool_calls) > 0, @@ -1218,9 +1254,12 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: if not previous_text: self.parser.reset_streaming_state() + # Reset tool call tracking arrays for new streaming session + 
self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] if request: self.parser.set_tools(request.tools) @@ -1232,20 +1271,48 @@ def extract_tool_calls_streaming( open_calls = current_text.count( self.parser.tool_call_start_token ) - current_text.count(self.parser.tool_call_end_token) - if open_calls == 0 and self.parser.tool_call_index > 0: - # If current_call_id is None, use last_completed_call_id - call_id = ( - self.parser.current_call_id or self.parser.last_completed_call_id - ) - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.parser.tool_call_index - 1, - id=call_id, - function=DeltaFunctionCall(arguments=""), - type="function", + if ( + open_calls == 0 + and self.parser.tool_call_index > 0 + or not self.parser.tool_call_index + and current_text + ): + return DeltaMessage(content="") + return None + + # Parse the delta text and get the result + result = self.parser.parse_single_streaming_chunks(delta_text) + + # Update tool call tracking arrays based on incremental parsing results + if result and result.tool_calls: + for tool_call in result.tool_calls: + if tool_call.function: + tool_index = ( + tool_call.index + if tool_call.index is not None + else len(self.prev_tool_call_arr) - 1 + ) + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({"name": "", "arguments": ""}) + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + + # Update tool name if provided + if tool_call.function.name: + self.prev_tool_call_arr[tool_index]["name"] = ( + tool_call.function.name ) - ] - ) - return self.parser.parse_single_streaming_chunks(delta_text) + # Update arguments incrementally + if tool_call.function.arguments is not None: + # Concatenate the incremental arguments + # to the existing streamed arguments + self.prev_tool_call_arr[tool_index]["arguments"] += ( + tool_call.function.arguments + ) + self.streamed_args_for_tool[tool_index] += ( + tool_call.function.arguments + ) + return result diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py index 2e7bd0d1d344..f50a2df53bc0 100644 --- a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py @@ -7,7 +7,7 @@ import json import uuid from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import regex as re @@ -109,8 +109,8 @@ def _reset_streaming_state(self): self.json_closed = False def _parse_xml_function_call( - self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] - ) -> Optional[ToolCall]: + self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None + ) -> ToolCall | None: def get_arguments_config(func_name: str) -> dict: if tools is None: return {} @@ -357,7 +357,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: # If no delta text, return None unless # it's an EOS token after tool calls if not delta_text: diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py index 34bd372b2060..d0255ec08539 100644 --- a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ 
-4,7 +4,7 @@ import contextlib import json from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any import regex as re @@ -51,6 +51,7 @@ def __init__(self, tokenizer: AnyTokenizer): self.tool_block_finished = False def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) if request.tools and request.tool_choice != "none": request.skip_special_tokens = False return request @@ -58,7 +59,7 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques @staticmethod def _parse_steptml_invoke( action_text: str, - ) -> tuple[Optional[str], Optional[dict[str, str]]]: + ) -> tuple[str | None, dict[str, str] | None]: func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text) if not func_name_match: return None, None @@ -117,7 +118,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: # The main loop processes the stream from the last known position. while True: if self.position >= len(current_text): diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index e076ab38e336..570eb447a467 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -6,8 +6,18 @@ from typing import Any import partial_json_parser +from openai.types.responses import ( + FunctionTool, + ToolChoiceFunction, +) +from openai.types.responses.tool import Tool from partial_json_parser.core.options import Allow +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionToolsParam, +) + def find_common_prefix(s1: str, s2: str) -> str: """ @@ -122,3 +132,98 @@ def consume_space(i: int, s: str) -> int: while i < len(s) and s[i].isspace(): i += 1 return i + + +def _extract_tool_info( + tool: Tool | ChatCompletionToolsParam, +) -> tuple[str, dict[str, Any] | None]: + if isinstance(tool, FunctionTool): + return tool.name, tool.parameters + elif isinstance(tool, ChatCompletionToolsParam): + return tool.function.name, tool.function.parameters + else: + raise TypeError(f"Unsupported tool type: {type(tool)}") + + +def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict: + name, params = _extract_tool_info(tool) + params = params if params else {"type": "object", "properties": {}} + return { + "properties": { + "name": {"type": "string", "enum": [name]}, + "parameters": params, + }, + "required": ["name", "parameters"], + } + + +def _get_tool_schema_defs( + tools: list[Tool | ChatCompletionToolsParam], +) -> dict: + all_defs: dict[str, dict[str, Any]] = {} + for tool in tools: + _, params = _extract_tool_info(tool) + if params is None: + continue + defs = params.pop("$defs", {}) + for def_name, def_schema in defs.items(): + if def_name in all_defs and all_defs[def_name] != def_schema: + raise ValueError( + f"Tool definition '{def_name}' has multiple schemas, " + "which is not supported." 
+ ) + all_defs[def_name] = def_schema + return all_defs + + +def _get_json_schema_from_tools( + tools: list[Tool | ChatCompletionToolsParam], +) -> dict: + json_schema = { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": [_get_tool_schema_from_tool(tool) for tool in tools], + }, + } + json_schema_defs = _get_tool_schema_defs(tools) + if json_schema_defs: + json_schema["$defs"] = json_schema_defs + return json_schema + + +def get_json_schema_from_tools( + tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam, + tools: list[FunctionTool | ChatCompletionToolsParam] | None, +) -> str | dict | None: + # tool_choice: "none" + if tool_choice in ("none", None) or tools is None: + return None + # tool_choice: Forced Function (Responses) + if (not isinstance(tool_choice, str)) and isinstance( + tool_choice, ToolChoiceFunction + ): + tool_name = tool_choice.name + tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)} + if tool_name not in tool_map: + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") + return tool_map[tool_name].parameters + # tool_choice: Forced Function (ChatCompletion) + if (not isinstance(tool_choice, str)) and isinstance( + tool_choice, ChatCompletionNamedToolChoiceParam + ): + tool_name = tool_choice.function.name + tool_map = { + tool.function.name: tool + for tool in tools + if isinstance(tool, ChatCompletionToolsParam) + } + if tool_name not in tool_map: + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") + return tool_map[tool_name].function.parameters + # tool_choice: "required" + if tool_choice == "required": + return _get_json_schema_from_tools(tools) + # tool_choice: "auto" + return None diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index 98c9cbbbd376..3c5a396a99f9 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -5,7 +5,7 @@ import io from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Annotated, Optional, Union +from typing import Annotated import pybase64 import torch @@ -17,32 +17,32 @@ from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import AsyncMicrobatchTokenizer +from vllm.utils.async_utils import AsyncMicrobatchTokenizer @dataclass(frozen=True) class RenderConfig: """Configuration to control how prompts are prepared.""" - max_length: Optional[int] = None + max_length: int | None = None """Maximum allowable total input token length. If provided, - token inputs longer than this raise ``ValueError``.""" + token inputs longer than this raise `ValueError`.""" - truncate_prompt_tokens: Optional[int] = None - """Number of tokens to keep. ``None`` means no truncation. - ``0`` yields an empty list (and skips embeds). - ``-1`` maps to ``model_config.max_model_len``.""" + truncate_prompt_tokens: int | None = None + """Number of tokens to keep. `None` means no truncation. + `0` yields an empty list (and skips embeds). 
+ `-1` maps to `model_config.max_model_len`.""" - add_special_tokens: Optional[bool] = True + add_special_tokens: bool | None = True """Whether to add model-specific special tokens during tokenization.""" - cache_salt: Optional[str] = None + cache_salt: str | None = None """String to disambiguate prefix cache entries.""" - needs_detokenization: Optional[bool] = False + needs_detokenization: bool | None = False """If True, detokenize IDs back to text for inclusion in outputs.""" - def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> Optional[int]: + def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> int | None: """Validate and normalize `truncate_prompt_tokens` parameter.""" truncate_prompt_tokens = self.truncate_prompt_tokens if truncate_prompt_tokens is None: @@ -85,7 +85,7 @@ class BaseRenderer(ABC): def __init__( self, model_config: ModelConfig, - tokenizer: Optional[AnyTokenizer] = None, + tokenizer: AnyTokenizer | None = None, ): super().__init__() self.model_config = model_config @@ -95,7 +95,7 @@ def __init__( async def render_prompt( self, *, - prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], + prompt_or_prompts: str | list[str] | list[int] | list[list[int]], config: RenderConfig, ) -> list[EngineTokensPrompt]: """ @@ -107,10 +107,10 @@ async def render_prompt( Args: prompt_or_prompts: One of: - - ``str``: Single text prompt. - - ``list[str]``: Batch of text prompts. - - ``list[int]``: Single pre-tokenized sequence. - - ``list[list[int]]``: Batch of pre-tokenized sequences. + - `str`: Single text prompt. + - `list[str]`: Batch of text prompts. + - `list[int]`: Single pre-tokenized sequence. + - `list[list[int]]`: Batch of pre-tokenized sequences. config: Render configuration controlling how prompts are prepared (e.g., tokenization and length handling). @@ -126,19 +126,17 @@ async def render_prompt( async def render_prompt_and_embeds( self, *, - prompt_or_prompts: Optional[ - Union[str, list[str], list[int], list[list[int]]] - ] = None, - prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, + prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, + prompt_embeds: bytes | list[bytes] | None = None, config: RenderConfig, - ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + ) -> list[EngineTokensPrompt | EngineEmbedsPrompt]: """ Convert text/token and/or base64-encoded embeddings inputs into engine-ready prompt objects using a unified RenderConfig. - At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be + At least one of `prompt_or_prompts` or `prompt_embeds` must be provided and non-empty. If both are omitted or empty (e.g., empty - string and empty list), a ``ValueError`` is raised. + string and empty list), a `ValueError` is raised. Args: prompt_or_prompts: Text or token inputs to include. @@ -152,20 +150,23 @@ async def render_prompt_and_embeds( Engine-ready prompt objects. Raises: - ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds`` + ValueError: If both `prompt_or_prompts` and `prompt_embeds` are omitted or empty (decoder prompt cannot be empty), or if length limits are exceeded. 
""" raise NotImplementedError - @classmethod def load_prompt_embeds( - cls, - prompt_embeds: Union[bytes, list[bytes]], - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None, - cache_salt: Optional[str] = None, + self, + prompt_embeds: bytes | list[bytes], + truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None, + cache_salt: str | None = None, ) -> list[EngineEmbedsPrompt]: """Load and validate base64-encoded embeddings into prompt objects.""" + if not self.model_config.enable_prompt_embeds: + raise ValueError( + "You must set `--enable-prompt-embeds` to input `prompt_embeds`." + ) def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: tensor = torch.load( @@ -199,19 +200,18 @@ class CompletionRenderer(BaseRenderer): def __init__( self, model_config: ModelConfig, - tokenizer: Optional[AnyTokenizer] = None, - async_tokenizer_pool: Optional[ - dict[AnyTokenizer, AsyncMicrobatchTokenizer] - ] = None, + tokenizer: AnyTokenizer | None = None, + async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] + | None = None, ): super().__init__(model_config, tokenizer) self.async_tokenizer_pool = async_tokenizer_pool - self.async_tokenizer: Optional[AsyncMicrobatchTokenizer] = None + self.async_tokenizer: AsyncMicrobatchTokenizer | None = None async def render_prompt( self, *, - prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], + prompt_or_prompts: str | list[str] | list[int] | list[list[int]], config: RenderConfig, ) -> list[EngineTokensPrompt]: """Implementation of prompt rendering for completion-style requests. @@ -237,12 +237,10 @@ async def render_prompt( async def render_prompt_and_embeds( self, *, - prompt_or_prompts: Optional[ - Union[str, list[str], list[int], list[list[int]]] - ] = None, - prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, + prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, + prompt_embeds: bytes | list[bytes] | None = None, config: RenderConfig, - ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: + ) -> list[EngineTokensPrompt | EngineEmbedsPrompt]: """ Render text/token prompts and/or precomputed embedding prompts. At least one of `prompt_or_prompts` or `prompt_embeds` must be provided. 
@@ -251,7 +249,7 @@ async def render_prompt_and_embeds( if truncate_prompt_tokens == 0: return [] - rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = [] + rendered: list[EngineTokensPrompt | EngineEmbedsPrompt] = [] if prompt_embeds is not None: rendered.extend( @@ -271,7 +269,7 @@ async def render_prompt_and_embeds( return rendered def _maybe_apply_truncation( - self, token_ids: list[int], truncate_prompt_tokens: Optional[int] + self, token_ids: list[int], truncate_prompt_tokens: int | None ) -> list[int]: """Apply truncation to token sequence.""" if truncate_prompt_tokens is None: @@ -283,9 +281,9 @@ def _maybe_apply_truncation( async def _create_prompt( self, - prompt_input: Union[EngineTextPrompt, EngineTokensPrompt], + prompt_input: EngineTextPrompt | EngineTokensPrompt, config: RenderConfig, - truncate_prompt_tokens: Optional[int], + truncate_prompt_tokens: int | None, ) -> EngineTokensPrompt: prompt, prompt_token_ids, _ = get_prompt_components(prompt_input) @@ -315,10 +313,10 @@ async def _create_prompt( async def _create_prompt_from_text( self, text: str, - max_length: Optional[int], - truncate_prompt_tokens: Optional[int], - add_special_tokens: Optional[bool], - cache_salt: Optional[str], + max_length: int | None, + truncate_prompt_tokens: int | None, + add_special_tokens: bool | None, + cache_salt: str | None, ) -> EngineTokensPrompt: """Tokenize text input asynchronously.""" async_tokenizer = self._get_async_tokenizer() @@ -348,10 +346,10 @@ async def _create_prompt_from_text( async def _create_prompt_from_token_ids( self, token_ids: list[int], - max_length: Optional[int], - truncate_prompt_tokens: Optional[int], - cache_salt: Optional[str], - needs_detokenization: Optional[bool] = False, + max_length: int | None, + truncate_prompt_tokens: int | None, + cache_salt: str | None, + needs_detokenization: bool | None = False, ) -> EngineTokensPrompt: """Optionally detokenize token IDs and build a tokens prompt.""" token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens) @@ -391,9 +389,9 @@ def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: def _create_tokens_prompt( self, token_ids: list[int], - max_length: Optional[int] = None, - cache_salt: Optional[str] = None, - prompt: Optional[str] = None, + max_length: int | None = None, + cache_salt: str | None = None, + prompt: str | None = None, ) -> EngineTokensPrompt: """Create validated EngineTokensPrompt.""" if max_length is not None and len(token_ids) > max_length: diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 1fb56d246deb..309a4c996392 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union, cast +from typing import Any, TypeAlias, cast from torch.nn import CosineSimilarity -from typing_extensions import Required, TypeAlias, TypedDict +from typing_extensions import Required, TypedDict from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ( @@ -25,9 +25,9 @@ PreTrainedTokenizerFast, ) -ScoreContentPartParam: TypeAlias = Union[ - ChatCompletionContentPartImageParam, ChatCompletionContentPartImageEmbedsParam -] +ScoreContentPartParam: TypeAlias = ( + ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam +) class ScoreMultiModalParam(TypedDict, total=False): @@ -45,12 +45,12 @@ class ScoreMultiModalParam(TypedDict, 
total=False): def _cosine_similarity( - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, embed_1: list[PoolingRequestOutput], embed_2: list[PoolingRequestOutput], ) -> list[PoolingRequestOutput]: scorer = CosineSimilarity(0) - scores: Union[list[PoolingRequestOutput]] = [] + scores: list[PoolingRequestOutput] = [] for emb_1, emb_2 in zip(embed_1, embed_2): pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) @@ -66,6 +66,7 @@ def _cosine_similarity( request_id=f"{emb_1.request_id}_{emb_2.request_id}", outputs=pair_score, prompt_token_ids=tokens, + num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens, finished=True, ) ) @@ -74,8 +75,8 @@ def _cosine_similarity( def _validate_score_input_lens( - data_1: Union[list[str], list[ScoreContentPartParam]], - data_2: Union[list[str], list[ScoreContentPartParam]], + data_1: list[str] | list[ScoreContentPartParam], + data_2: list[str] | list[ScoreContentPartParam], ): len_1 = len(data_1) len_2 = len(data_2) @@ -89,18 +90,18 @@ def _validate_score_input_lens( def parse_score_data( - data_1: Union[str, ScoreContentPartParam], - data_2: Union[str, ScoreContentPartParam], + data_1: str | ScoreContentPartParam, + data_2: str | ScoreContentPartParam, model_config: ModelConfig, tokenizer: AnyTokenizer, -) -> tuple[str, str, Optional[MultiModalDataDict]]: +) -> tuple[str, str, MultiModalDataDict | None]: mm_tracker = MultiModalItemTracker(model_config, tokenizer) content_1 = _parse_score_content(data_1, mm_tracker) content_2 = _parse_score_content(data_2, mm_tracker) - def ensure_str(content: Optional[_ContentPart]) -> str: + def ensure_str(content: _ContentPart | None) -> str: if content is not None and isinstance(content, str): return cast(str, content) else: @@ -113,9 +114,9 @@ def ensure_str(content: Optional[_ContentPart]) -> str: def _parse_score_content( - data: Union[str, ScoreContentPartParam], + data: str | ScoreContentPartParam, mm_tracker: BaseMultiModalItemTracker, -) -> Optional[_ContentPart]: +) -> _ContentPart | None: if isinstance(data, str): data = ChatCompletionContentPartTextParam(type="text", text=data) @@ -182,8 +183,8 @@ def get_score_prompt( model_config: ModelConfig, tokenizer: AnyTokenizer, tokenization_kwargs: dict[str, Any], - data_1: Union[str, ScoreContentPartParam], - data_2: Union[str, ScoreContentPartParam], + data_1: str | ScoreContentPartParam, + data_2: str | ScoreContentPartParam, ) -> tuple[str, TokensPrompt]: prompt_1, prompt_2, mm_data = parse_score_data( data_1, diff --git a/vllm/entrypoints/ssl.py b/vllm/entrypoints/ssl.py index ff0dd1bbfc6b..4d947bc620cf 100644 --- a/vllm/entrypoints/ssl.py +++ b/vllm/entrypoints/ssl.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +from collections.abc import Callable from ssl import SSLContext -from typing import Callable, Optional from watchfiles import Change, awatch @@ -20,9 +20,9 @@ class SSLCertRefresher: def __init__( self, ssl_context: SSLContext, - key_path: Optional[str] = None, - cert_path: Optional[str] = None, - ca_path: Optional[str] = None, + key_path: str | None = None, + cert_path: str | None = None, + ca_path: str | None = None, ) -> None: self.ssl = ssl_context self.key_path = key_path diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py index b3dceecc1583..0d83031ef69f 100644 --- a/vllm/entrypoints/tool_server.py +++ b/vllm/entrypoints/tool_server.py @@ -2,7 +2,7 @@ # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from openai_harmony import ToolDescription, ToolNamespaceConfig @@ -80,7 +80,7 @@ def has_tool(self, tool_name: str) -> bool: pass @abstractmethod - def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]: + def get_tool_description(self, tool_name: str) -> ToolNamespaceConfig | None: """ Return the tool description for the given tool name. If the tool is not supported, return None. @@ -89,7 +89,7 @@ def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]: @abstractmethod def new_session( - self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None + self, tool_name: str, session_id: str, headers: dict[str, str] | None = None ) -> AbstractAsyncContextManager[Any]: """ Create a session for the tool. @@ -152,7 +152,7 @@ def get_tool_description(self, tool_name: str): @asynccontextmanager async def new_session( - self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None + self, tool_name: str, session_id: str, headers: dict[str, str] | None = None ): from mcp import ClientSession from mcp.client.sse import sse_client @@ -190,7 +190,7 @@ async def init_and_validate(self): def has_tool(self, tool_name: str) -> bool: return tool_name in self.tools - def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]: + def get_tool_description(self, tool_name: str) -> ToolNamespaceConfig | None: if tool_name not in self.tools: return None if tool_name == "browser": @@ -202,7 +202,7 @@ def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]: @asynccontextmanager async def new_session( - self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None + self, tool_name: str, session_id: str, headers: dict[str, str] | None = None ): if tool_name not in self.tools: raise KeyError(f"Tool '{tool_name}' is not supported") diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index c97ca6538814..ec5fb3b56b7f 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -6,17 +6,31 @@ import functools import os from argparse import Namespace -from typing import Any, Optional, Union +from pathlib import Path +from typing import Any from fastapi import Request from fastapi.responses import JSONResponse, StreamingResponse from starlette.background import BackgroundTask, BackgroundTasks +from vllm.config import ModelConfig from vllm.engine.arg_utils import EngineArgs +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ( + load_chat_template, + resolve_hf_chat_template, + resolve_mistral_chat_template, +) from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + CompletionRequest, + StreamOptions, +) +from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.utils import FlexibleArgumentParser logger = init_logger(__name__) @@ -164,9 +178,9 @@ def cli_env_setup(): def _validate_truncation_size( max_model_len: int, - 
truncate_prompt_tokens: Optional[int], - tokenization_kwargs: Optional[dict[str, Any]] = None, -) -> Optional[int]: + truncate_prompt_tokens: int | None, + tokenization_kwargs: dict[str, Any] | None = None, +) -> int | None: if truncate_prompt_tokens is not None: if truncate_prompt_tokens <= -1: truncate_prompt_tokens = max_model_len @@ -191,7 +205,7 @@ def _validate_truncation_size( def get_max_tokens( max_model_len: int, - request: Union[ChatCompletionRequest, CompletionRequest], + request: ChatCompletionRequest | CompletionRequest, input_length: int, default_sampling_params: dict, ) -> int: @@ -211,7 +225,7 @@ def get_max_tokens( ) -def log_non_default_args(args: Union[Namespace, EngineArgs]): +def log_non_default_args(args: Namespace | EngineArgs): non_default_args = {} # Handle Namespace @@ -237,3 +251,69 @@ def log_non_default_args(args: Union[Namespace, EngineArgs]): ) logger.info("non-default args: %s", non_default_args) + + +def should_include_usage( + stream_options: StreamOptions | None, enable_force_include_usage: bool +) -> tuple[bool, bool]: + if stream_options: + include_usage = stream_options.include_usage or enable_force_include_usage + include_continuous_usage = include_usage and bool( + stream_options.continuous_usage_stats + ) + else: + include_usage, include_continuous_usage = enable_force_include_usage, False + return include_usage, include_continuous_usage + + +def process_lora_modules( + args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None +) -> list[LoRAModulePath]: + lora_modules = args_lora_modules + if default_mm_loras: + default_mm_lora_paths = [ + LoRAModulePath( + name=modality, + path=lora_path, + ) + for modality, lora_path in default_mm_loras.items() + ] + if args_lora_modules is None: + lora_modules = default_mm_lora_paths + else: + lora_modules += default_mm_lora_paths + return lora_modules + + +async def process_chat_template( + args_chat_template: Path | str | None, + engine_client: EngineClient, + model_config: ModelConfig, +) -> str | None: + resolved_chat_template = load_chat_template(args_chat_template) + if resolved_chat_template is not None: + # Get the tokenizer to check official template + tokenizer = await engine_client.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + # The warning is logged in resolve_mistral_chat_template. + resolved_chat_template = resolve_mistral_chat_template( + chat_template=resolved_chat_template + ) + else: + hf_chat_template = resolve_hf_chat_template( + tokenizer=tokenizer, + chat_template=None, + tools=None, + model_config=model_config, + ) + + if hf_chat_template != resolved_chat_template: + logger.warning( + "Using supplied chat template: %s\n" + "It is different from official chat template '%s'. 
" + "This discrepancy may lead to performance degradation.", + resolved_chat_template, + model_config.model, + ) + return resolved_chat_template diff --git a/vllm/env_override.py b/vllm/env_override.py index 7f9054e73846..ae3e4e751bd9 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -5,6 +5,7 @@ import torch from vllm.logger import init_logger +from vllm.utils.torch_utils import is_torch_equal logger = init_logger(__name__) @@ -21,3 +22,339 @@ os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" # see https://github.com/vllm-project/vllm/issues/10619 torch._inductor.config.compile_threads = 1 + +# =================================================== +# torch 2.9 Inductor PythonWrapperCodegen monkeypatch +# =================================================== +# This change monkeypatches memory_plan_reuse in pytorch 2.9.0 to work around +# a test failure for test_multi_graph_piecewise_compile_outputs_equal. +# For more context, see https://github.com/pytorch/pytorch/pull/165514. + + +def memory_plan_reuse_patched(self): + import torch._inductor.ir as ir + from torch._inductor.codegen.wrapper import ( + EnterSubgraphLine, + ExitSubgraphLine, + MemoryPlanningLine, + MemoryPlanningState, + SubgraphPythonWrapperCodegen, + ) + from torch._inductor.virtualized import V + + def get_output_names(graph_outputs) -> list[str]: + import itertools + + names = [] + shape_counter = itertools.count(0) + none_counter = itertools.count(0) + for node in graph_outputs: + if isinstance(node, ir.NoneAsConstantBuffer): + names.append(f"{V.graph.name}_none{next(none_counter)}") + elif isinstance(node, ir.ShapeAsConstantBuffer): + names.append(f"{V.graph.name}_shape{next(shape_counter)}") + else: + names.append(node.get_name()) + return names + + if ( + isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen) + and V.graph.wrapper_code.partition_signatures is not None + ): + out_names = get_output_names( + V.graph.wrapper_code.partition_signatures.output_nodes + ) + else: + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + +# =================================================== +# torch 2.9 Inductor get_graph_partition_signature monkeypatch +# =================================================== +# This change monkeypatches get_graph_partition_signature in pytorch 2.9.0 to +# fix inductor partition + attention-nvfp4 quant fusion, tested in +# `tests/compile/test_fusions_e2e.py::test_attn_quant`. +# For more context, see https://github.com/pytorch/pytorch/pull/165815. + + +def get_graph_partition_signature_patched( + self, partitions, skip_cudagraphs: list[bool] +): + """ + Gets signature for each graph partition, including input nodes, output nodes, and + whether deallocating an input within graph partition. 
+ """ + from torch._inductor import dependencies + from torch._inductor.ir import GraphPartitionSignature, MutationOutput, NoneLayout + from torch._inductor.virtualized import V + from torch.utils._ordered_set import OrderedSet + + signatures = [] + + unmet_output_names = OrderedSet(V.graph.get_output_names()) + name_to_node = self.get_name_to_nodes() + + def is_none_layout(buf_name: str) -> bool: + """ + Checks if buf_name is NoneLayout. Buffers with NoneLayout is not allocated + so graph partition should not take it as inputs or outputs. + """ + buf = self.name_to_buf.get(buf_name, None) + + if buf is None: + return False + + if isinstance(buf.node.layout, NoneLayout): + if isinstance(buf.node, MutationOutput) and ( + real_name := self.mutation_real_name.get(buf_name, None) + ): + return is_none_layout(real_name) + + return True + + return False + + for partition, skip_cudagraph in zip( + reversed(partitions), reversed(skip_cudagraphs) + ): + output_names: OrderedSet[str] = OrderedSet() + + for node in partition: + output_names.update(node.outputs_by_name.keys()) + + returned_output_names = output_names.intersection(unmet_output_names) + + # all reads/writes are partition inputs except those generated + # within the partition and tensor constants + read_writes = dependencies.ReadWrites.merge_list( + [node.read_writes for node in partition] + ) + + # WeakDep is fake dependency on unused buffer. It should not appear + # in partition_input_names for inputs that are actually read or written. + partition_input_names = ( + OrderedSet( + [ + x.name + for x in read_writes.reads | read_writes.writes + if not is_none_layout(x.name) + ] + ) + - output_names + ) + + partition_input_names = OrderedSet( + self.mutation_real_name.get(name, name) for name in partition_input_names + ) + + buffer_names_to_free: OrderedSet[str] = OrderedSet() + for node in partition: + buffer_names_to_free.update(node.last_usage) + + # buffer_names_to_free may contain buffers allocated in previous + # graph partitions. These buffers should also be a partition + # input. + extra_input_names = [ + name + for name in (buffer_names_to_free - output_names) + if name in name_to_node + ] + partition_input_names.update(extra_input_names) + + input_nodes = { + name: name_to_node[name] + for name in partition_input_names + if name in name_to_node + } + input_deallocation = { + name: name in buffer_names_to_free + for name in partition_input_names + if name in name_to_node + } + + # if an input tensor is not freed in the partition function, it should + # also be returned as an output. This brings benefits to cudagraph + # since the returned output tensor is a cudagraph managed tensor with + # a static tensor address. 
+ extra_output_names = [ + name + for name in partition_input_names + if name in name_to_node and name not in buffer_names_to_free + ] + + returned_output_names.update(extra_output_names) + + returned_output_names = OrderedSet( + self.mutation_real_name.get(name, name) for name in returned_output_names + ) + + output_nodes = [ + name_to_node[name] + for name in returned_output_names + if not is_none_layout(name) + ] + + constant_names = [ + name for name in partition_input_names if name in V.graph.constants + ] + + symbol_inputs = self.get_graph_partition_symbol_inputs(partition, input_nodes) + + partition_signature = GraphPartitionSignature( + symbol_inputs, + input_nodes, + output_nodes, + input_deallocation, + skip_cudagraph, + constant_names, + ) + + signatures.append(partition_signature) + + unmet_output_names = partition_input_names.union( + unmet_output_names - returned_output_names + ) + + return signatures[::-1] + + +# ======================================== +# torch 2.9 Inductor Scheduler monkeypatch +# ======================================== +# This change monkeypatches a function in Inductor to work around the following +# bug: https://github.com/vllm-project/vllm/issues/26678 +# +# The bug occurs when `use_inductor_graph_partition` is turned on and there +# exists operators inside of `splitting_ops` that have an in-place mutation. In +# vllm, this specifically occurs on the operator +# vllm.unified_attention_with_output. In this case, inductor does not populate +# the inductor IR's `origin_node` field, causing an assertion error when trying +# to access the node's `origin_node` field. +# +# So, we will monkeypatch torch._inductor.scheduler.Scheduler.should_partition +# so that it does not access the inductor IR node's `origin_node` field and just +# returns True if a node is registered as having a custom partition function. +# This is ok for now since vllm's implementation of the custom partition +# functions just return True. 
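To make the trigger condition above concrete: the problematic operators are ones that write into a caller-provided buffer instead of returning a tensor, as vllm.unified_attention_with_output does. A hypothetical registration of such a mutating op via torch.library.custom_op (illustrative only, assumes torch >= 2.4, and is not how vLLM registers its own op):

import torch

@torch.library.custom_op("demo::attention_with_output", mutates_args=("out",))
def attention_with_output(query: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for a real attention kernel: the result lands in `out`,
    # so the op mutates an argument in place and returns nothing.
    out.copy_(query)

q = torch.randn(4, 8)
buf = torch.empty_like(q)
attention_with_output(q, buf)  # output is written into buf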
+# ======================================== + + +def should_partition_patched(self, node, should_log: bool = False) -> bool: + # This is a patched version of + # torch._inductor.scheduler.Scheduler.should_partition that modifies + # the following piece of code so that we always return True: + # https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/_inductor/scheduler.py#L4712-L4724 + """Return True if we should partition the inductor graph on this node""" + + import torch._inductor.ir as ir + from torch._inductor.scheduler import ( + BaseSchedulerNode, + FusedSchedulerNode, + _custom_should_partition_fns, + ) + from torch._inductor.utils import ( + _unstable_customized_partition_wrapper, + is_cudagraph_unsafe_op, + maybe_log_cudagraph_partition, + ) + + # Allow users to manually specify if a node should be partitioned + # Can only do this for FallbackKernels + ir_node = node.node + if isinstance(ir_node, ir.FallbackKernel): + operator = ir_node.op_overload + if operator is not None and operator in _custom_should_partition_fns: + return True + + # When not using cudagraphs, keep all kernels in the `call` function + # instead of graph partition functions, since graph partition only brings + # benefit to cudagraph + if ( + not torch._inductor.config.triton.cudagraphs + and _unstable_customized_partition_wrapper.wrapper is None + ): + return True + + # avoid duplicating logs when should_partition is called multiple times + # on the same node + def noop_log(msg: str, node: BaseSchedulerNode | None) -> None: + return + + log_partition_reason = maybe_log_cudagraph_partition if should_log else noop_log + + if isinstance(node, FusedSchedulerNode): + return any(self.should_partition(snode) for snode in node.snodes) + + assert node.node is not None + + if not node.is_gpu(): + log_partition_reason("non gpu ops", node=node) + + return True + + if isinstance(node.node, ir.DeviceCopy): + log_partition_reason("DeviceCopy ops", node=node) + return True + + if isinstance(node.node, ir.Conditional): + log_partition_reason("Conditional ops", node=node) + return True + + if getattr(node.node, "unbacked_bindings", None): + log_partition_reason("unbacked binding ops", node=node) + return True + + if is_cudagraph_unsafe_op(node.node): + log_partition_reason("CUDAGraph-unsafe custom ops", node=node) + return True + + return False + + +def _update_scheduler_patched(self) -> None: + # Copied from torch._inductor.graph.GrahLowering._update_scheduler. Patches + # this method so that we can patch Scheduler.should_partition with the + # function above + """ + (Re)initializes the scheduler member. When initializing the scheduler, no CUBIN + files should be generated (to avoid biasing any benchmarks and pessimizing + fusion decisions). 
+ """ + import torch._inductor.config as config + from torch._inductor.scheduler import Scheduler + + Scheduler.should_partition = should_partition_patched + Scheduler.get_graph_partition_signature = get_graph_partition_signature_patched + + with config.patch("triton.store_cubin", False): + self.scheduler = Scheduler(self.operations) + + +if is_torch_equal("2.9.0"): + from torch._inductor.codegen.wrapper import PythonWrapperCodegen + from torch._inductor.graph import GraphLowering + + PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched + GraphLowering._update_scheduler = _update_scheduler_patched diff --git a/vllm/envs.py b/vllm/envs.py index 01e93f224a30..0c45f93ec057 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1,32 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools import hashlib import json import os import sys import tempfile -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: VLLM_HOST_IP: str = "" - VLLM_PORT: Optional[int] = None + VLLM_PORT: int | None = None VLLM_RPC_BASE_PATH: str = tempfile.gettempdir() VLLM_USE_MODELSCOPE: bool = False VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 - VLLM_NCCL_SO_PATH: Optional[str] = None - LD_LIBRARY_PATH: Optional[str] = None + VLLM_NCCL_SO_PATH: str | None = None + LD_LIBRARY_PATH: str | None = None VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False - VLLM_FLASH_ATTN_VERSION: Optional[int] = None + VLLM_FLASH_ATTN_VERSION: int | None = None LOCAL_RANK: int = 0 - CUDA_VISIBLE_DEVICES: Optional[str] = None + CUDA_VISIBLE_DEVICES: str | None = None VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 - VLLM_API_KEY: Optional[str] = None - S3_ACCESS_KEY_ID: Optional[str] = None - S3_SECRET_ACCESS_KEY: Optional[str] = None - S3_ENDPOINT_URL: Optional[str] = None - VLLM_MODEL_REDIRECT_PATH: Optional[str] = None + VLLM_API_KEY: str | None = None + VLLM_DEBUG_LOG_API_SERVER_RESPONSE: bool = False + S3_ACCESS_KEY_ID: str | None = None + S3_SECRET_ACCESS_KEY: str | None = None + S3_ENDPOINT_URL: str | None = None + VLLM_MODEL_REDIRECT_PATH: str | None = None VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm") VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm") VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" @@ -38,24 +41,21 @@ VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_STREAM: str = "ext://sys.stdout" - VLLM_LOGGING_CONFIG_PATH: Optional[str] = None - VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None + VLLM_LOGGING_CONFIG_PATH: str | None = None VLLM_LOG_STATS_INTERVAL: float = 10.0 VLLM_TRACE_FUNCTION: int = 0 - VLLM_ATTENTION_BACKEND: Optional[str] = None - VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None - VLLM_PP_LAYER_PARTITION: Optional[str] = None - VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0 + VLLM_ATTENTION_BACKEND: str | None = None + VLLM_USE_FLASHINFER_SAMPLER: bool | None = None + VLLM_PP_LAYER_PARTITION: str | None = None + VLLM_CPU_KVCACHE_SPACE: int | None = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" - VLLM_CPU_NUM_OF_RESERVED_CPU: Optional[int] = None + VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_MOE_PREPACK: bool = True VLLM_CPU_SGL_KERNEL: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True - VLLM_USE_RAY_SPMD_WORKER: bool = False - VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True @@ -73,22 +73,24 @@ VLLM_MM_INPUT_CACHE_GIB: int = 4 VLLM_TARGET_DEVICE: str = "cuda" VLLM_MAIN_CUDA_VERSION: str = "12.8" - MAX_JOBS: Optional[str] = None - NVCC_THREADS: Optional[str] = None + MAX_JOBS: str | None = None + NVCC_THREADS: str | None = None VLLM_USE_PRECOMPILED: bool = False VLLM_DOCKER_BUILD_CONTEXT: bool = False VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False - CMAKE_BUILD_TYPE: Optional[Literal["Debug", "Release", "RelWithDebInfo"]] = None + CMAKE_BUILD_TYPE: Literal["Debug", "Release", "RelWithDebInfo"] | None = None VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds - VLLM_PLUGINS: Optional[list[str]] = None - VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None - VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_PLUGINS: list[str] | None = None + VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None + VLLM_TORCH_PROFILER_DIR: str | None = None VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False + VLLM_USE_AOT_COMPILE: bool = False + VLLM_FORCE_AOT_LOAD: bool = False VLLM_TORCH_PROFILER_WITH_STACK: bool = True VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False VLLM_USE_TRITON_AWQ: bool = False @@ -108,6 +110,7 @@ VLLM_ROCM_USE_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False + VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -124,25 +127,29 @@ VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 32 VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" - VLLM_CUDART_SO_PATH: Optional[str] = None + VLLM_CUDART_SO_PATH: str | None = None VLLM_DP_RANK: int = 0 VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 - VLLM_USE_STANDALONE_COMPILE: bool = False + VLLM_USE_STANDALONE_COMPILE: bool = True VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False + VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_MARLIN_USE_ATOMIC_ADD: bool = False - VLLM_MXFP4_USE_MARLIN: Optional[bool] = None - VLLM_V0_USE_OUTLINES_CACHE: bool = False + VLLM_MXFP4_USE_MARLIN: bool | None = None VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 - VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None + VLLM_TPU_MOST_MODEL_LEN: int | None = None VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True - VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False + VLLM_DEEP_GEMM_WARMUP: Literal[ + "skip", + "full", + "relax", + ] = "relax" VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False @@ -166,46 +173,51 @@ VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 - VLLM_KV_CACHE_LAYOUT: Optional[Literal["NHD", "HND"]] = None + VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None VLLM_COMPUTE_NANS_IN_LOGITS: 
bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[ "FP", "INT8", "INT6", "INT4", "NONE" ] = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True - VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None + VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_USE_CUDNN_PREFILL: bool = False + VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False VLLM_LOOPBACK_IP: str = "" VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False - VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None + VLLM_USE_TRTLLM_ATTENTION: str | None = None + VLLM_NVFP4_GEMM_BACKEND: str | None = None VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION: bool = False VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False - VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True - VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None - VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False + VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024 + VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE: bool = False + VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK: bool = False + VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL: bool = False VLLM_DBO_COMM_SMS: int = 20 GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = [] - VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None - VLLM_DEBUG_DUMP_PATH: Optional[str] = None + VLLM_PATTERN_MATCH_DEBUG: str | None = None + VLLM_DEBUG_DUMP_PATH: str | None = None VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True VLLM_USE_NCCL_SYMM_MEM: bool = False - VLLM_NCCL_INCLUDE_PATH: Optional[str] = None + VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" + VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False def get_default_cache_root(): @@ -222,24 +234,31 @@ def get_default_config_root(): ) -def maybe_convert_int(value: Optional[str]) -> Optional[int]: +def maybe_convert_int(value: str | None) -> int | None: if value is None: return None return int(value) -def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: +def maybe_convert_bool(value: str | None) -> bool | None: if value is None: return None return bool(int(value)) +def use_aot_compile() -> bool: + from vllm.utils.torch_utils import is_torch_equal_or_newer + + default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0" + return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" + + def env_with_choices( env_name: str, - default: Optional[str], - choices: Union[list[str], Callable[[], list[str]]], + default: str | None, + choices: list[str] | Callable[[], list[str]], case_sensitive: bool = True, -) -> Callable[[], Optional[str]]: +) -> Callable[[], str | None]: """ Create a lambda that validates environment variable against allowed choices @@ -253,7 +272,7 @@ def env_with_choices( Lambda function for environment_variables dict """ - def _get_validated_env() -> Optional[str]: 
+ def _get_validated_env() -> str | None: value = os.getenv(env_name) if value is None: return default @@ -282,7 +301,7 @@ def _get_validated_env() -> Optional[str]: def env_list_with_choices( env_name: str, default: list[str], - choices: Union[list[str], Callable[[], list[str]]], + choices: list[str] | Callable[[], list[str]], case_sensitive: bool = True, ) -> Callable[[], list[str]]: """ @@ -334,7 +353,7 @@ def _get_validated_env_list() -> list[str]: return _get_validated_env_list -def get_vllm_port() -> Optional[int]: +def get_vllm_port() -> int | None: """Get the port from VLLM_PORT environment variable. Returns: @@ -479,10 +498,10 @@ def get_vllm_port() -> Optional[int]: os.environ.get("VLLM_FLASH_ATTN_VERSION", None) ), # Feature flag to enable/disable Inductor standalone compile. - # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is - # disabled by default. + # In torch <= 2.7 we ignore this flag; in torch >= 2.9 this is + # enabled by default. "VLLM_USE_STANDALONE_COMPILE": lambda: os.environ.get( - "VLLM_USE_STANDALONE_COMPILE", "0" + "VLLM_USE_STANDALONE_COMPILE", "1" ) == "1", # Debug pattern matching inside custom passes. @@ -493,6 +512,14 @@ def get_vllm_port() -> Optional[int]: # Dump fx graphs to the given directory. # It will override CompilationConfig.debug_dump_path if set. "VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None), + # Feature flag to enable/disable AOT compilation. This will ensure + # compilation is done in warmup phase and the compilation will be + # reused in subsequent calls. + "VLLM_USE_AOT_COMPILE": use_aot_compile, + # Force vllm to always load AOT compiled models from disk. Failure + # to load will result in a hard error when this is enabled. + # Will be ignored when VLLM_USE_AOT_COMPILE is disabled. + "VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1", # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), @@ -541,15 +568,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_LOGGING_STREAM": lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"), # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages "VLLM_LOGGING_PREFIX": lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), - # if set, vllm will call logits processors in a thread pool with this many - # threads. This is useful when using custom logits processors that either - # (a) launch additional CUDA kernels or (b) do significant CPU-bound work - # while not holding the python GIL, or both. - "VLLM_LOGITS_PROCESSOR_THREADS": lambda: int( - os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0") - ) - if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ - else None, # If set, vllm will log stats at this interval in seconds # If not set, vllm will log stats every 10 seconds. "VLLM_LOG_STATS_INTERVAL": lambda: val @@ -608,22 +626,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), # (CPU backend only) whether to use SGL kernels, optimized for small batch. "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), - # If the env var is set, then all workers will execute as separate - # processes from the engine, and we use the same mechanism to trigger - # execution on all workers. - # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. 
- "VLLM_USE_RAY_SPMD_WORKER": lambda: bool( - int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0")) - ), - # If the env var is set, it uses the Ray's Compiled Graph - # (previously known as ADAG) API which optimizes the - # control plane overhead. - # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. - # Note that this variable is set to 1 in V1 by default - # when ray distributed executor is used. - "VLLM_USE_RAY_COMPILED_DAG": lambda: bool( - int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0")) - ), # If the env var is set, Ray Compiled Graph uses the specified # channel type to communicate between workers belonging to # different pipeline-parallel stages. @@ -631,20 +633,17 @@ def get_vllm_port() -> Optional[int]: # - "auto": use the default channel type # - "nccl": use NCCL for communication # - "shm": use shared memory and gRPC for communication - # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices( "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto", ["auto", "nccl", "shm"] ), # If the env var is set, it enables GPU communication overlap - # (experimental feature) in Ray's Compiled Graph. This flag is ignored if - # VLLM_USE_RAY_COMPILED_DAG is not set. + # (experimental feature) in Ray's Compiled Graph. "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": lambda: bool( int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) ), # If the env var is set, it uses a Ray Communicator wrapping # vLLM's pipeline parallelism communicator to interact with Ray's # Compiled Graph. Otherwise, it uses Ray's NCCL communicator. - # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. "VLLM_USE_RAY_WRAPPED_PP_COMM": lambda: bool( int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1")) ), @@ -889,6 +888,12 @@ def get_vllm_port() -> Optional[int]: os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1") ), + # Whether to use aiter fusion shared experts ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower() + in ("true", "1") + ), # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1") @@ -1000,6 +1005,20 @@ def get_vllm_port() -> Optional[int]: "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0" ) == "1", + # Strategy to pack the data parallel ranks for Ray. + # Available options: + # - "fill": + # for DP master node, allocate exactly data-parallel-size-local DP ranks, + # for non-master nodes, allocate as many DP ranks as can fit; + # - "strict": + # allocate exactly data-parallel-size-local DP ranks to each picked node; + # - "span": + # Should be used only when a single DP rank requires multiple nodes. + # allocate one DP rank over as many nodes as required for set world_size; + # This environment variable is ignored if data-parallel-backend is not Ray. + "VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv( + "VLLM_RAY_DP_PACK_STRATEGY", "strict" + ), # Whether to use S3 path for model loading in CI via RunAI Streamer "VLLM_CI_USE_S3": lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", # Use model_redirect to redirect the model name to a local folder. 
@@ -1020,13 +1039,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( os.environ.get("VLLM_MXFP4_USE_MARLIN", None) ), - # Whether to turn on the outlines cache for V0 - # This cache is unbounded and on disk, so it's not safe to use in - # an environment with potentially malicious users. - "VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get( - "VLLM_V0_USE_OUTLINES_CACHE", "0" - ) - == "1", # Whether to turn on the outlines cache for V1 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. @@ -1058,9 +1070,21 @@ def get_vllm_port() -> Optional[int]: # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine # startup time by a couple of minutes. - # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup. - "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool( - int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0")) + # Available options: + # - "skip" : Skip warmup. + # - "full" : Warmup deepgemm by running all possible gemm shapes the + # engine could encounter. + # - "relax" : Select gemm shapes to run based on some heuristics. The + # heuristic aims to have the same effect as running all possible gemm + # shapes, but provides no guarantees. + "VLLM_DEEP_GEMM_WARMUP": env_with_choices( + "VLLM_DEEP_GEMM_WARMUP", + "relax", + [ + "skip", + "full", + "relax", + ], ), # Whether to use fused grouped_topk used for MoE expert selection. "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool( @@ -1230,6 +1254,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_CUDNN_PREFILL": lambda: bool( int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) ), + # Controls whether to use TRT-LLM ragged DeepSeek prefill + "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool( + int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0")) + ), # If set to 1/True, use the TRTLLM attention backend in flashinfer. # If set to 0/False, use the default attention backend in flashinfer. # If not set, auto-detect the attention backend in flashinfer. @@ -1245,11 +1273,15 @@ def get_vllm_port() -> Optional[int]: # If set, it means we pre-downloaded cubin files and flashinfer will # read the cubin files directly. "VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), - # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. - # Otherwise, uses the first available of: flashinfer cutlass GEMM, - # vllm cutlass GEMM, marlin GEMM. - "VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool( - int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0")) + # Supported options: + # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend + # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend + # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend + # - <none>: automatically pick an available backend + "VLLM_NVFP4_GEMM_BACKEND": env_with_choices( + "VLLM_NVFP4_GEMM_BACKEND", + None, + ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"], ), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. @@ -1257,12 +1289,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_ENABLE_CUDAGRAPH_GC": lambda: bool( int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0")) ), - # Disable padding to CUDA graph capture batch sizes. - # TODO(wentao): https://github.com/vllm-project/vllm/issues/23378 - # After the issue is fixed, we can remove this flag. 
- "VLLM_DISABLE_PAD_FOR_CUDAGRAPH": lambda: bool( - int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0")) - ), # Used to force set up loopback IP "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), # Used to set the process name prefix for vLLM processes. @@ -1297,7 +1323,7 @@ def get_vllm_port() -> Optional[int]: ), # Whether to use pytorch symmetric memory for allreduce "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( - int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) + int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0")) ), # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), @@ -1327,6 +1353,22 @@ def get_vllm_port() -> Optional[int]: "VLLM_DEEPEP_BUFFER_SIZE_MB": lambda: int( os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024") ), + # Force DeepEP to use intranode kernel for inter-node communication in + # high throughput mode. This is useful archive higher prefill throuhgput + # on system supports multi-node nvlink (e.g GB200). + "VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE": lambda: bool( + int(os.getenv("VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE", "0")) + ), + # Allow DeepEP to use nvlink for internode_ll kernel, turn this on for + # better latency on GB200 like system + "VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK": lambda: bool( + int(os.getenv("VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK", "0")) + ), + # Allow DeepEP to use MNNVL (multi-node nvlink) for internode_ll kernel, + # turn this for better latency on GB200 like system + "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL": lambda: bool( + int(os.getenv("VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL", "0")) + ), # The number of SMs to allocate for communication kernels when running DBO # the rest of the SMs on the device will be allocated to compute "VLLM_DBO_COMM_SMS": lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")), @@ -1362,18 +1404,46 @@ def get_vllm_port() -> Optional[int]: # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # top 5 collected objects "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), + # Disables parallel execution of shared_experts via separate cuda stream + "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv( + "VLLM_DISABLE_SHARED_EXPERTS_STREAM", False + ), } # --8<-- [end:env-vars-definition] def __getattr__(name: str): - # lazy evaluation of environment variables + """ + Gets environment variables lazily. + + NOTE: After enable_envs_cache() invocation (which triggered after service + initialization), all environment variables will be cached. + """ if name in environment_variables: return environment_variables[name]() raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +def enable_envs_cache() -> None: + """ + Enables caching of environment variables. This is useful for performance + reasons, as it avoids the need to re-evaluate environment variables on + every call. + + NOTE: Currently, it's invoked after service initialization to reduce + runtime overhead. This also means that environment variables should NOT + be updated after the service is initialized. 
+ """ + # Tag __getattr__ with functools.cache + global __getattr__ + __getattr__ = functools.cache(__getattr__) + + # Cache all environment variables + for key in environment_variables: + __getattr__(key) + + def __dir__(): return list(environment_variables.keys()) @@ -1427,7 +1497,6 @@ def compute_hash() -> str: "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", "VLLM_USE_DEEP_GEMM_E8M0", - "VLLM_USE_TRTLLM_FP4_GEMM", "VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FLASHINFER_MOE_FP16", "VLLM_USE_FLASHINFER_MOE_FP8", @@ -1436,6 +1505,7 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "VLLM_USE_CUDNN_PREFILL", + "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "VLLM_USE_TRTLLM_ATTENTION", "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "VLLM_ROCM_USE_AITER", @@ -1459,7 +1529,11 @@ def compute_hash() -> str: "VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", + "VLLM_NVFP4_GEMM_BACKEND", "VLLM_USE_FBGEMM", + "VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE", + "VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK", + "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL", ] for key in environment_variables_to_hash: # if this goes out of sync with environment_variables, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py deleted file mode 100644 index 3a7347b8e465..000000000000 --- a/vllm/executor/executor_base.py +++ /dev/null @@ -1,388 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import time -from abc import ABC, abstractmethod -from collections.abc import Awaitable -from functools import cached_property -from typing import Any, Callable, Optional, Union - -from typing_extensions import TypeVar - -import vllm.platforms -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest -from vllm.tasks import SupportedTask -from vllm.utils import make_async -from vllm.v1.outputs import PoolerOutput, SamplerOutput -from vllm.v1.worker.worker_base import WorkerBase - -logger = init_logger(__name__) - -_R = TypeVar("_R", default=Any) - - -class ExecutorBase(ABC): - """Base class for all executors. - - An executor is responsible for executing the model on one device, - or it can be a distributed executor - that can execute the model on multiple devices. - """ - - uses_ray: bool # whether the executor uses Ray for orchestration. 
- supports_pp: bool = False # whether the executor supports PP - - def __init__( - self, - vllm_config: VllmConfig, - ) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - self._init_executor() - self.is_sleeping = False - self.sleeping_tags: set[str] = set() - self.kv_output_aggregator = None - - @abstractmethod - def _init_executor(self) -> None: - raise NotImplementedError - - @abstractmethod - def collective_rpc( - self, - method: Union[str, Callable[[WorkerBase], _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, - ) -> list[_R]: - """ - Execute an RPC call on all workers. - - Args: - method: Name of the worker method to execute, or a callable that - is serialized and sent to all workers to execute. - - If the method is a callable, it should accept an additional - `self` argument, in addition to the arguments passed in `args` - and `kwargs`. The `self` argument will be the worker object. - timeout: Maximum time in seconds to wait for execution. Raises a - [`TimeoutError`][] on timeout. `None` means wait indefinitely. - args: Positional arguments to pass to the worker method. - kwargs: Keyword arguments to pass to the worker method. - - Returns: - A list containing the results from each worker. - - Note: - It is recommended to use this API to only pass control messages, - and set up data-plane communication to pass data. - """ - raise NotImplementedError - - def determine_num_available_blocks(self) -> tuple[int, int]: - """Determine the number of available blocks for the GPU KV cache and - swappable CPU KV cache. - - Normally, this should simply delegate to the underlying Worker. Some - ExecutorBase may require modification of the result, e.g. to ensure the - selected cache sizes are compatible with all workers. - - Returns a tuple `(num_gpu_blocks, num_cpu_blocks)`, where - `num_gpu_blocks` are blocks that are "active" on the device and can be - appended to. - `num_cpu_blocks` refers to "swapped" blocks in CPU memory and cannot be - appended to. - """ - results = self.collective_rpc("determine_num_available_blocks") - a = min([r[0] for r in results]) - b = min([r[1] for r in results]) - return a, b - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: - """Initialize the KV cache by invoking the underlying worker.""" - # NOTE: This is logged in the executor because there can be >1 workers. 
- logger.info( - "# %s blocks: %d, # CPU blocks: %d", - vllm.platforms.current_platform.device_name, - num_gpu_blocks, - num_cpu_blocks, - ) - max_concurrency = ( - num_gpu_blocks - * self.cache_config.block_size - / self.model_config.max_model_len - ) - logger.info( - "Maximum concurrency for %s tokens per request: %.2fx", - self.model_config.max_model_len, - max_concurrency, - ) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) - - @cached_property # Avoid unnecessary RPC calls - def supported_tasks(self) -> tuple[SupportedTask, ...]: - output = self.collective_rpc("get_supported_tasks") - return output[0] - - def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> Optional[list[Union[SamplerOutput, PoolerOutput]]]: - output = self.collective_rpc("execute_model", args=(execute_model_req,)) - return output[0] - - def stop_remote_worker_execution_loop(self) -> None: - """Releases parallel workers from model loop.""" - return - - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("add_lora", args=(lora_request,))) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("remove_lora", args=(lora_id,))) - - def pin_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("pin_lora", args=(lora_id,))) - - def list_loras(self) -> set[int]: - sets = self.collective_rpc("list_loras") - for s in sets: - assert s == sets[0], "All workers should have the same LORAs." - return sets[0] - - def start_profile(self) -> None: - self.collective_rpc("start_profile") - - def stop_profile(self) -> None: - self.collective_rpc("stop_profile") - - def sleep(self, level: int = 1): - if self.is_sleeping: - logger.warning("Executor is already sleeping.") - return - time_before_sleep = time.perf_counter() - self.collective_rpc("sleep", kwargs=dict(level=level)) - time_after_sleep = time.perf_counter() - self.sleeping_tags = {"weights", "kv_cache"} - self.is_sleeping = True - logger.info( - "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep - ) - - def wake_up(self, tags: Optional[list[str]] = None): - if not self.is_sleeping: - logger.warning("Executor is not sleeping.") - return - if tags: - for tag in tags: - if tag not in self.sleeping_tags: - logger.warning( - "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags - ) - return - time_before_wakeup = time.perf_counter() - self.collective_rpc("wake_up", kwargs=dict(tags=tags)) - time_after_wakeup = time.perf_counter() - logger.info( - "It took %.6f seconds to wake up tags %s.", - time_after_wakeup - time_before_wakeup, - tags if tags is not None else self.sleeping_tags, - ) - if tags: - for tag in tags: - self.sleeping_tags.remove(tag) - else: - self.sleeping_tags.clear() - if not self.sleeping_tags: - self.is_sleeping = False - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - self.collective_rpc( - "save_sharded_state", - kwargs=dict(path=path, pattern=pattern, max_size=max_size), - ) - - @abstractmethod - def check_health(self) -> None: - """Checks if the executor is healthy. 
If not, it should raise an - exception.""" - raise NotImplementedError - - def shutdown(self) -> None: - """Shutdown the executor.""" - self.collective_rpc("shutdown") - - async def execute_model_async( - self, execute_model_req: ExecuteModelRequest - ) -> list[SamplerOutput]: - """Executes one model step on the given sequences.""" - output = await make_async(self.execute_model)(execute_model_req) - return output - - async def stop_remote_worker_execution_loop_async(self) -> None: - """Releases parallel workers from model loop.""" - return - - async def check_health_async(self) -> None: - """Checks if the executor is healthy. If not, it should raise an - exception.""" - self.check_health() - - def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None: - """Init KVOutputAggregator""" - self.kv_output_aggregator = KVOutputAggregator( - finished_count or self.parallel_config.world_size - ) - - -class DistributedExecutorBase(ExecutorBase): - """Abstract superclass of distributed executor implementations.""" - - def __init__(self, *args, **kwargs): - # This is non-None when the execute model loop is running - # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. - self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None - - super().__init__(*args, **kwargs) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest, - ) -> list[SamplerOutput]: - # TODO: unify into collective_rpc - if self.parallel_worker_tasks is None: - self.parallel_worker_tasks = self._run_workers( - "start_worker_execution_loop", - async_run_tensor_parallel_workers_only=True, - ) - - # Only the driver worker returns the sampling results. - driver_outputs = self._driver_execute_model(execute_model_req) - assert driver_outputs is not None - return driver_outputs - - def stop_remote_worker_execution_loop(self) -> None: - if self.parallel_worker_tasks is None: - return - - self._driver_execute_model(execute_model_req=None) - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - self._wait_for_tasks_completion(parallel_worker_tasks) - - @abstractmethod - def _driver_execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[list[SamplerOutput]]: - """Run execute_model in the driver worker. - - Passing None will cause the driver to stop the model execution loop - running in each of the remote workers. In this case, this method - returns None. Otherwise, this method returns the model output. - """ - raise NotImplementedError - - def collective_rpc( - self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, - ) -> list[Any]: - return self._run_workers(method, *args, **(kwargs or {})) - - @abstractmethod - def _run_workers( - self, - method: Union[str, Callable], - *args, - async_run_tensor_parallel_workers_only: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. - - Args: - async_run_tensor_parallel_workers_only: If True the method will be - run only in the remote TP workers, not the driver worker. - It will also be run asynchronously and return a list of futures - rather than blocking on the results. 
- - # TODO: simplify and merge with collective_rpc - """ - raise NotImplementedError - - @abstractmethod - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - raise NotImplementedError - - async def execute_model_async( - self, execute_model_req: ExecuteModelRequest - ) -> list[SamplerOutput]: - if self.parallel_worker_tasks is None: - # Start model execution loop running in the parallel workers - self.parallel_worker_tasks = asyncio.create_task( - self._start_worker_execution_loop() - ) - - # Only the driver worker returns the sampling results. - return await self._driver_execute_model_async(execute_model_req) - - async def stop_remote_worker_execution_loop_async(self) -> None: - if self.parallel_worker_tasks is None: - return - - await self._driver_execute_model_async() - parallel_worker_tasks = self.parallel_worker_tasks - self.parallel_worker_tasks = None - # Ensure that workers exit model loop cleanly - # (this will raise otherwise) - await parallel_worker_tasks - - @abstractmethod - async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> list[SamplerOutput]: - """Execute the model asynchronously in the driver worker. - - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - raise NotImplementedError - - @abstractmethod - async def _start_worker_execution_loop(self): - """Run execution loop on all workers. It guarantees all workers run - the loop or None of them is running the loop. Loop can be stopped by - `stop_remote_worker_execution_loop`. - The API is idempotent (guarantee only 1 loop run at any moment).""" - raise NotImplementedError diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py deleted file mode 100644 index ac16f06b160e..000000000000 --- a/vllm/executor/msgspec_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from array import array -from typing import Any - -from vllm.multimodal.inputs import MultiModalKwargs -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE - - -def encode_hook(obj: Any) -> Any: - """Custom msgspec enc hook that supports array types and MultiModalKwargs. - - See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder - """ - if isinstance(obj, array): - assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( - f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " - f"Given array has a type code of {obj.typecode}." - ) - return obj.tobytes() - if isinstance(obj, MultiModalKwargs): - return dict(obj) - - -def decode_hook(type: type, obj: Any) -> Any: - """Custom msgspec dec hook that supports array types and MultiModalKwargs. 
- - See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder - """ - if type is array: - deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) - deserialized.frombytes(obj) - return deserialized - if type is MultiModalKwargs: - return MultiModalKwargs(obj) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index a6a1e36bfe95..ef37cf862c9f 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -5,13 +5,14 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Union import torch import vllm.envs as envs from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.logger import init_logger +from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ubatch_utils import UBatchSlices if TYPE_CHECKING: @@ -39,13 +40,19 @@ class BatchDescriptor(NamedTuple): False can also be used for an uniform decode batch to dispatch to the cudagraph supporting non-uniform batches. """ + has_lora: bool = False + """ + Whether this batch has active LoRA adapters. + """ @property def non_uniform(self) -> "BatchDescriptor": """ Return a non-uniform version of current batch descriptor. """ - return BatchDescriptor(self.num_tokens, uniform_decode=False) + return BatchDescriptor( + self.num_tokens, uniform_decode=False, has_lora=self.has_lora + ) def _compute_sp_num_tokens( @@ -83,7 +90,7 @@ class DPMetadata: num_tokens_across_dp_cpu: torch.Tensor # NOTE: local_sizes should only be set by the chunked_sizes context manager - local_sizes: Optional[list[int]] = None + local_sizes: list[int] | None = None @staticmethod def make( @@ -157,10 +164,21 @@ def sp_local_sizes(self, sequence_parallel_size: int): finally: self.local_sizes = None - def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]: + def get_chunk_sizes_across_dp_rank(self) -> list[int] | None: assert self.local_sizes is not None return self.local_sizes + # Get the cumulative tokens across sequence parallel ranks. + # In this case the input to the MoEs will be distributed w.r.t both + # DP and TP rank. + # When sp_size==1, this is just the cummulative num tokens across DP. + def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor: + num_tokens_across_sp_cpu = ( + self.num_tokens_across_dp_cpu - 1 + sp_size + ) // sp_size + num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size) + return torch.cumsum(num_tokens_across_sp_cpu, dim=0) + @dataclass class ForwardContext: @@ -182,13 +200,13 @@ class ForwardContext: # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass - dp_metadata: Optional[DPMetadata] = None + dp_metadata: DPMetadata | None = None # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. # by default NONE, no cudagraph is used. 
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE - batch_descriptor: Optional[BatchDescriptor] = None + batch_descriptor: BatchDescriptor | None = None - ubatch_slices: Optional[UBatchSlices] = None + ubatch_slices: UBatchSlices | None = None def __post_init__(self): assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( @@ -196,7 +214,7 @@ def __post_init__(self): ) -_forward_context: Optional[ForwardContext] = None +_forward_context: ForwardContext | None = None def get_forward_context() -> ForwardContext: @@ -212,10 +230,10 @@ def create_forward_context( attn_metadata: Any, vllm_config: VllmConfig, virtual_engine: int = 0, - dp_metadata: Optional[DPMetadata] = None, + dp_metadata: DPMetadata | None = None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None, - ubatch_slices: Optional[UBatchSlices] = None, + batch_descriptor: BatchDescriptor | None = None, + ubatch_slices: UBatchSlices | None = None, ): return ForwardContext( no_compile_layers=vllm_config.compilation_config.static_forward_context, @@ -229,7 +247,7 @@ def create_forward_context( @contextmanager -def override_forward_context(forward_context: Optional[ForwardContext]): +def override_forward_context(forward_context: ForwardContext | None): """A context manager that overrides the current forward context. This is used to override the forward context for a specific forward pass. @@ -248,11 +266,11 @@ def set_forward_context( attn_metadata: Any, vllm_config: VllmConfig, virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, + num_tokens: int | None = None, + num_tokens_across_dp: torch.Tensor | None = None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None, - ubatch_slices: Optional[UBatchSlices] = None, + batch_descriptor: BatchDescriptor | None = None, + ubatch_slices: UBatchSlices | None = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -263,15 +281,33 @@ def set_forward_context( if need_to_track_batchsize: forward_start_time = time.perf_counter() - dp_metadata: Optional[DPMetadata] = None + dp_metadata: DPMetadata | None = None if vllm_config.parallel_config.data_parallel_size > 1 and ( attn_metadata is not None or num_tokens is not None ): - assert num_tokens_across_dp is not None + # If num_tokens_across_dp hasn't already been initialized, then + # initialize it here. Both DP padding and Microbatching will be + # disabled. + if num_tokens_across_dp is None: + assert ubatch_slices is None + assert num_tokens is not None + _, num_tokens_across_dp = coordinate_batch_across_dp( + num_tokens_unpadded=num_tokens, + parallel_config=vllm_config.parallel_config, + allow_microbatching=False, + allow_dp_padding=False, + ) + assert num_tokens_across_dp is not None dp_metadata = DPMetadata.make( vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp ) + # Convenience: if cudagraph is used and num_tokens is given, we can just + # create a batch descriptor here if not given (there's no harm since if it + # doesn't match in the wrapper it'll fall through). 
+ if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None: + batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens) + forward_context = create_forward_context( attn_metadata, vllm_config, diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c463723e5d0e..1f138a72d084 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast import torch from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar @@ -12,6 +12,10 @@ MultiModalInputs, MultiModalUUIDDict, ) +else: + MultiModalDataDict = object + MultiModalInputs = object + MultiModalUUIDDict = object class TextPrompt(TypedDict): @@ -20,13 +24,13 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" - multi_modal_data: NotRequired["MultiModalDataDict"] + multi_modal_data: NotRequired[MultiModalDataDict | None] """ Optional multi-modal data to pass to the model, if the model supports it. """ - mm_processor_kwargs: NotRequired[dict[str, Any]] + mm_processor_kwargs: NotRequired[dict[str, Any] | None] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities @@ -34,7 +38,7 @@ class TextPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ - multi_modal_uuids: NotRequired["MultiModalUUIDDict"] + multi_modal_uuids: NotRequired[MultiModalUUIDDict] """ Optional user-specified UUIDs for multimodal items, mapped by modality. Lists must match the number of items per modality and may contain `None`. @@ -61,13 +65,13 @@ class TokensPrompt(TypedDict): token_type_ids: NotRequired[list[int]] """A list of token type IDs to pass to the cross encoder model.""" - multi_modal_data: NotRequired["MultiModalDataDict"] + multi_modal_data: NotRequired[MultiModalDataDict | None] """ Optional multi-modal data to pass to the model, if the model supports it. """ - mm_processor_kwargs: NotRequired[dict[str, Any]] + mm_processor_kwargs: NotRequired[dict[str, Any] | None] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. Note that if multiple modalities @@ -75,7 +79,7 @@ class TokensPrompt(TypedDict): to pass the mm_processor_kwargs to each of them. """ - multi_modal_uuids: NotRequired["MultiModalUUIDDict"] + multi_modal_uuids: NotRequired[MultiModalUUIDDict] """ Optional user-specified UUIDs for multimodal items, mapped by modality. Lists must match the number of items per modality and may contain `None`. 
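# Minimal sketch of how the prompt TypedDicts above are meant to be populated;
# the prompt text, token ids and UUID value are made-up placeholders, and in
# real usage `multi_modal_data` would carry the actual image object.
from vllm.inputs.data import TextPrompt, TokensPrompt

text_prompt = TextPrompt(
    prompt="Describe this image.",
    multi_modal_uuids={"image": ["img-0001"]},  # one entry per item of the modality
)

tokens_prompt = TokensPrompt(
    prompt_token_ids=[1, 2, 3, 4],
    multi_modal_data=None,  # the NotRequired fields may be omitted or set to None
)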
@@ -111,7 +115,7 @@ class DataPrompt(TypedDict): """The input data format""" -SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] +SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt """ Set of possible schemas for a single prompt: @@ -185,12 +189,12 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): encoder_prompt: _T1_co - decoder_prompt: Optional[_T2_co] + decoder_prompt: _T2_co | None mm_processor_kwargs: NotRequired[dict[str, Any]] -PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] +PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -220,7 +224,7 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: list[int], - cache_salt: Optional[str] = None, + cache_salt: str | None = None, ) -> TokenInputs: """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional values.""" @@ -249,7 +253,7 @@ class EmbedsInputs(TypedDict): def embeds_inputs( prompt_embeds: torch.Tensor, - cache_salt: Optional[str] = None, + cache_salt: str | None = None, ) -> EmbedsInputs: """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional values.""" @@ -261,7 +265,7 @@ def embeds_inputs( return inputs -DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] +DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs """ The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are passed to the model executor. @@ -277,20 +281,20 @@ class EncoderDecoderInputs(TypedDict): This specifies the required data for encoder-decoder models. """ - encoder: Union[TokenInputs, "MultiModalInputs"] + encoder: TokenInputs | MultiModalInputs """The inputs for the encoder portion.""" - decoder: Union[TokenInputs, "MultiModalInputs"] + decoder: TokenInputs | MultiModalInputs """The inputs for the decoder portion.""" -SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] +SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs """ A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be passed to [`Sequence`][collections.abc.Sequence]. """ -ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] +ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs """ The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][]. """ @@ -301,8 +305,8 @@ class EncoderDecoderInputs(TypedDict): def build_explicit_enc_dec_prompt( encoder_prompt: _T1, - decoder_prompt: Optional[_T2], - mm_processor_kwargs: Optional[dict[str, Any]] = None, + decoder_prompt: _T2 | None, + mm_processor_kwargs: dict[str, Any] | None = None, ) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: if mm_processor_kwargs is None: mm_processor_kwargs = {} @@ -315,17 +319,15 @@ def build_explicit_enc_dec_prompt( def zip_enc_dec_prompts( enc_prompts: Iterable[_T1], - dec_prompts: Iterable[Optional[_T2]], - mm_processor_kwargs: Optional[ - Union[Iterable[dict[str, Any]], dict[str, Any]] - ] = None, + dec_prompts: Iterable[_T2 | None], + mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None, ) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] instances. 
- ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same + `mm_processor_kwargs` may also be provided; if a dict is passed, the same dictionary will be used for every encoder/decoder prompt. If an iterable is provided, it will be zipped with the encoder/decoder prompts. """ @@ -350,7 +352,7 @@ def zip_enc_dec_prompts( def to_enc_dec_tuple_list( enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], -) -> list[tuple[_T1, Optional[_T2]]]: +) -> list[tuple[_T1, _T2 | None]]: return [ (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 2f7bd50df022..211551be8e60 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast from typing_extensions import TypeIs -from vllm.utils import is_list_of +from vllm.utils.collection_utils import is_list_of from .data import ( EmbedsPrompt, @@ -23,8 +23,8 @@ def parse_raw_prompts( - prompt: Union[str, list[str], list[int], list[list[int]]], -) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]: + prompt: str | list[str] | list[int] | list[list[int]], +) -> Sequence[TextPrompt] | Sequence[TokensPrompt]: if isinstance(prompt, str): # case 1: a string return [TextPrompt(prompt=prompt)] @@ -76,9 +76,9 @@ class ParsedEmbedsPrompt(TypedDict): content: EmbedsPrompt -ParsedSingletonPrompt = Union[ - ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, ParsedEmbedsPrompt -] +ParsedSingletonPrompt: TypeAlias = ( + ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt +) def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: @@ -106,7 +106,7 @@ def is_explicit_encoder_decoder_prompt( def split_enc_dec_inputs( inputs: ProcessorInputs, -) -> tuple[Optional[SingletonInputs], SingletonInputs]: +) -> tuple[SingletonInputs | None, SingletonInputs]: if "encoder" in inputs and "decoder" in inputs: # NOTE: This passes pyright but not mypy return ( @@ -118,9 +118,9 @@ def split_enc_dec_inputs( class PromptComponents(NamedTuple): - text: Optional[str] = None - token_ids: Optional[list[int]] = None - embeds: Optional["torch.Tensor"] = None + text: str | None = None + token_ids: list[int] | None = None + embeds: "torch.Tensor | None" = None def get_prompt_components(prompt: PromptType) -> PromptComponents: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 00f30e483693..12363d4f1f9f 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping -from typing import Any, Optional, Union, cast +from typing import Any, cast from typing_extensions import assert_never @@ -17,8 +17,9 @@ MultiModalUUIDDict, ) from vllm.multimodal.processing import BaseMultiModalProcessor -from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.jsontree import json_iter_leaves +from vllm.v1.metrics.stats import MultiModalCacheStats from .data import ( DecoderOnlyInputs, @@ -45,19 +46,18 @@ class InputPreprocessor: def 
__init__( self, model_config: ModelConfig, + tokenizer: AnyTokenizer | None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, + mm_processor_cache: BaseMultiModalProcessorCache | None = None, ) -> None: super().__init__() self.model_config = model_config + self.tokenizer = tokenizer self.mm_registry = mm_registry self.mm_processor_cache = mm_processor_cache - if model_config.skip_tokenizer_init: - self.tokenizer = None - else: - self.tokenizer = init_tokenizer_from_configs(model_config) + self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: @@ -67,25 +67,25 @@ def get_tokenizer(self) -> AnyTokenizer: return self.tokenizer - def get_bos_token_id(self) -> Optional[int]: + def get_bos_token_id(self) -> int | None: if self.tokenizer is None: - logger.warning( + logger.warning_once( "Using None for BOS token id because tokenizer is not initialized" ) return None return self.tokenizer.bos_token_id - def get_eos_token_id(self) -> Optional[int]: + def get_eos_token_id(self) -> int | None: if self.tokenizer is None: - logger.warning( + logger.warning_once( "Using None for EOS token id because tokenizer is not initialized" ) return None return self.tokenizer.eos_token_id - def get_decoder_start_token_id(self) -> Optional[int]: + def get_decoder_start_token_id(self) -> int | None: """ Obtain the decoder start token id employed by an encoder/decoder model. Returns None for non-encoder/decoder models or if the @@ -157,7 +157,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> list[int]: def _prepare_decoder_input_ids_for_generation( self, - decoder_input_ids: Optional[list[int]], + decoder_input_ids: list[int] | None, ) -> list[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. @@ -194,7 +194,7 @@ def _prepare_decoder_input_ids_for_generation( def _get_tokenization_kw( self, - overrides: Optional[dict[str, Any]] = None, + overrides: dict[str, Any] | None = None, ) -> dict[str, Any]: kwargs = dict[str, Any]() @@ -212,7 +212,7 @@ def _get_tokenization_kw( def _tokenize_prompt( self, prompt: str, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> list[int]: """ Apply the model's tokenizer to a text prompt, returning the @@ -221,6 +221,13 @@ def _tokenize_prompt( tokenizer = self.get_tokenizer() tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) + bos_token_text = getattr(tokenizer, "bos_token", None) + if (bos_token_text and isinstance(bos_token_text, str) and + prompt.lstrip().startswith(bos_token_text) and + "add_special_tokens" not in tokenization_kwargs): + # override if not explicitly set by caller. 
+ tokenization_kwargs["add_special_tokens"] = False + encoder_config = self.model_config.encoder_config if encoder_config and encoder_config.get("do_lower_case", False): @@ -251,12 +258,12 @@ def _get_mm_processor(self) -> BaseMultiModalProcessor: def _process_multimodal( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, - mm_processor_kwargs: Optional[Mapping[str, object]], - tokenization_kwargs: Optional[dict[str, Any]] = None, + mm_processor_kwargs: Mapping[str, object] | None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, @@ -320,7 +327,7 @@ def _process_embeds( ) def _truncate_inputs( - self, inputs: list[int], tokenization_kwargs: Optional[dict[str, Any]] = None + self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None ) -> list[int]: if ( not tokenization_kwargs @@ -339,20 +346,20 @@ def _truncate_inputs( def _process_tokens( self, parsed_content: TokensPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, - ) -> Union[TokenInputs, MultiModalInputs]: + mm_uuids: MultiModalUUIDDict | None = None, + ) -> TokenInputs | MultiModalInputs: prompt_token_ids = self._truncate_inputs( parsed_content["prompt_token_ids"], tokenization_kwargs ) - inputs: Union[TokenInputs, MultiModalInputs] + inputs: TokenInputs | MultiModalInputs if self.model_config.is_multimodal_model: inputs = self._process_multimodal( prompt_token_ids, - parsed_content.get("multi_modal_data", {}), - parsed_content.get("mm_processor_kwargs"), + parsed_content.get("multi_modal_data") or {}, + parsed_content.get("mm_processor_kwargs") or {}, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) @@ -370,18 +377,18 @@ def _process_tokens( def _process_text( self, parsed_content: TextPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, - ) -> Union[TokenInputs, MultiModalInputs]: + mm_uuids: MultiModalUUIDDict | None = None, + ) -> TokenInputs | MultiModalInputs: prompt_text = parsed_content["prompt"] - inputs: Union[TokenInputs, MultiModalInputs] + inputs: TokenInputs | MultiModalInputs if self.model_config.is_multimodal_model: inputs = self._process_multimodal( prompt_text, - parsed_content.get("multi_modal_data", {}), - parsed_content.get("mm_processor_kwargs"), + parsed_content.get("multi_modal_data") or {}, + parsed_content.get("mm_processor_kwargs") or {}, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) @@ -403,9 +410,9 @@ def _process_text( def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. 
@@ -445,7 +452,7 @@ def _prompt_to_llm_inputs( def _build_enc_dec_llm_inputs( self, encoder_inputs: SingletonInputs, - decoder_inputs: Optional[SingletonInputs], + decoder_inputs: SingletonInputs | None, ) -> EncoderDecoderInputs: if ( encoder_inputs["type"] == "embeds" @@ -457,10 +464,8 @@ def _build_enc_dec_llm_inputs( ) # Needed for mypy - encoder_inputs = cast(Union[TokenInputs, MultiModalInputs], encoder_inputs) - decoder_inputs = cast( - Optional[Union[TokenInputs, MultiModalInputs]], decoder_inputs - ) + encoder_inputs = cast(TokenInputs | MultiModalInputs, encoder_inputs) + decoder_inputs = cast(TokenInputs | MultiModalInputs | None, decoder_inputs) if decoder_inputs is None: if self.model_config.hf_config.model_type == "whisper": @@ -491,8 +496,8 @@ def _build_enc_dec_llm_inputs( def _split_enc_dec_mm_inputs( self, - inputs: Union[SingletonInputs, MultiModalEncDecInputs], - decoder_inputs_to_override: Optional[SingletonInputs] = None, + inputs: SingletonInputs | MultiModalEncDecInputs, + decoder_inputs_to_override: SingletonInputs | None = None, ) -> tuple[SingletonInputs, SingletonInputs]: """ For encoder/decoder models only: @@ -509,11 +514,11 @@ def _split_enc_dec_mm_inputs( # Needed for mypy inputs = cast( - Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs], + TokenInputs | MultiModalInputs | MultiModalEncDecInputs, inputs, ) decoder_inputs_to_override = cast( - Optional[Union[TokenInputs, MultiModalInputs]], + TokenInputs | MultiModalInputs | None, decoder_inputs_to_override, ) @@ -553,9 +558,9 @@ def _split_enc_dec_mm_inputs( def _process_encoder_decoder_prompt( self, prompt: PromptType, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> EncoderDecoderInputs: """ For encoder/decoder models only: @@ -591,7 +596,7 @@ def _process_encoder_decoder_prompt( instance """ encoder_inputs: SingletonInputs - decoder_inputs: Optional[SingletonInputs] + decoder_inputs: SingletonInputs | None if is_explicit_encoder_decoder_prompt(prompt): # `cast` is needed for mypy, but not pyright @@ -633,7 +638,7 @@ def _build_decoder_only_llm_inputs( ) -> DecoderOnlyInputs: if "prompt_token_ids" in prompt_inputs: prompt_inputs = cast( - Union[TokenInputs, MultiModalInputs], prompt_inputs + TokenInputs | MultiModalInputs, prompt_inputs ) # Needed for mypy return prompt_inputs @@ -641,9 +646,9 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, prompt: SingletonPrompt, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> DecoderOnlyInputs: """ For decoder-only models: @@ -667,14 +672,13 @@ def _process_decoder_only_prompt( return self._build_decoder_only_llm_inputs(prompt_comps) - def preprocess( + def _preprocess( self, prompt: PromptType, - tokenization_kwargs: Optional[dict[str, Any]] = None, + tokenization_kwargs: dict[str, Any] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> ProcessorInputs: - """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder. 
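# Sketch of the explicit encoder/decoder prompt shape that the encoder-decoder
# branch above dispatches on, built with the helper from vllm.inputs.data; the
# prompt text is a placeholder, and passing None as the decoder prompt lets the
# model fall back to its default decoder start token.
from vllm.inputs.data import build_explicit_enc_dec_prompt

enc_dec_prompt = build_explicit_enc_dec_prompt(
    encoder_prompt="Translate to German: The house is small.",
    decoder_prompt=None,
)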
@@ -697,6 +701,40 @@ def preprocess( mm_uuids=mm_uuids, ) - def clear_cache(self) -> None: + def preprocess( + self, + prompt: PromptType, + tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> ProcessorInputs: + """Preprocess the input prompt.""" + res = self._preprocess( + prompt, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + if self.mm_processor_cache and self.mm_cache_stats is not None: + delta = self.mm_processor_cache.make_stats(delta=True) + self.mm_cache_stats.requests += 1 + self.mm_cache_stats.queries += delta.total + self.mm_cache_stats.hits += delta.hits + + return res + + def stat_mm_cache(self) -> MultiModalCacheStats | None: + mm_cache_stats = self.mm_cache_stats + if mm_cache_stats is None: + return None + + self.mm_cache_stats = MultiModalCacheStats() + + return mm_cache_stats + + def clear_mm_cache(self) -> None: if self.mm_processor_cache is not None: self.mm_processor_cache.clear_cache() + + if self.mm_cache_stats is not None: + self.mm_cache_stats.reset = True diff --git a/vllm/logger.py b/vllm/logger.py index 37e8495768c0..934100829684 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -13,7 +13,7 @@ from logging.config import dictConfig from os import path from types import MethodType -from typing import Any, Optional, cast +from typing import Any, Literal, cast import vllm.envs as envs @@ -59,20 +59,37 @@ @lru_cache def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None: - # Set the stacklevel to 2 to print the original caller's line info - logger.debug(msg, *args, stacklevel=2) + # Set the stacklevel to 3 to print the original caller's line info + logger.debug(msg, *args, stacklevel=3) @lru_cache def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None: - # Set the stacklevel to 2 to print the original caller's line info - logger.info(msg, *args, stacklevel=2) + # Set the stacklevel to 3 to print the original caller's line info + logger.info(msg, *args, stacklevel=3) @lru_cache def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None: - # Set the stacklevel to 2 to print the original caller's line info - logger.warning(msg, *args, stacklevel=2) + # Set the stacklevel to 3 to print the original caller's line info + logger.warning(msg, *args, stacklevel=3) + + +LogScope = Literal["process", "global", "local"] + + +def _should_log_with_scope(scope: LogScope) -> bool: + """Decide whether to log based on scope""" + if scope == "global": + from vllm.distributed.parallel_state import is_global_first_rank + + return is_global_first_rank() + if scope == "local": + from vllm.distributed.parallel_state import is_local_first_rank + + return is_local_first_rank() + # default "process" scope: always log + return True class _VllmLogger(Logger): @@ -84,33 +101,43 @@ class _VllmLogger(Logger): `intel_extension_for_pytorch.utils._logger`. """ - def debug_once(self, msg: str, *args: Hashable) -> None: + def debug_once( + self, msg: str, *args: Hashable, scope: LogScope = "process" + ) -> None: """ As [`debug`][logging.Logger.debug], but subsequent calls with the same message are silently dropped. """ + if not _should_log_with_scope(scope): + return _print_debug_once(self, msg, *args) - def info_once(self, msg: str, *args: Hashable) -> None: + def info_once(self, msg: str, *args: Hashable, scope: LogScope = "process") -> None: """ As [`info`][logging.Logger.info], but subsequent calls with the same message are silently dropped. 
""" + if not _should_log_with_scope(scope): + return _print_info_once(self, msg, *args) - def warning_once(self, msg: str, *args: Hashable) -> None: + def warning_once( + self, msg: str, *args: Hashable, scope: LogScope = "process" + ) -> None: """ As [`warning`][logging.Logger.warning], but subsequent calls with the same message are silently dropped. """ + if not _should_log_with_scope(scope): + return _print_warning_once(self, msg, *args) # Pre-defined methods mapping to avoid repeated dictionary creation _METHODS_TO_PATCH = { - "debug_once": _print_debug_once, - "info_once": _print_info_once, - "warning_once": _print_warning_once, + "debug_once": _VllmLogger.debug_once, + "info_once": _VllmLogger.info_once, + "warning_once": _VllmLogger.warning_once, } @@ -217,7 +244,7 @@ def _trace_calls(log_path, root_dir, frame, event, arg=None): return partial(_trace_calls, log_path, root_dir) -def enable_trace_function_call(log_file_path: str, root_dir: Optional[str] = None): +def enable_trace_function_call(log_file_path: str, root_dir: str | None = None): """ Enable tracing of every function call in code under `root_dir`. This is useful for debugging hangs or crashes. diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py index 3a97000647d6..cb289d04e3f4 100644 --- a/vllm/logging_utils/dump_input.py +++ b/vllm/logging_utils/dump_input.py @@ -4,7 +4,6 @@ import contextlib import enum import json -from typing import Optional import torch @@ -57,7 +56,7 @@ def prepare_object_to_dump(obj) -> str: def dump_engine_exception( config: VllmConfig, scheduler_output: SchedulerOutput, - scheduler_stats: Optional[SchedulerStats], + scheduler_stats: SchedulerStats | None, ): # NOTE: ensure we can log extra info without risking raises # unexpected errors during logging @@ -68,7 +67,7 @@ def dump_engine_exception( def _dump_engine_exception( config: VllmConfig, scheduler_output: SchedulerOutput, - scheduler_stats: Optional[SchedulerStats], + scheduler_stats: SchedulerStats | None, ): logger.error( "Dumping input data for V1 LLM engine (v%s) with config: %s, ", diff --git a/vllm/logits_process.py b/vllm/logits_process.py index 6ac30ae0028e..7b6a6528e20e 100644 --- a/vllm/logits_process.py +++ b/vllm/logits_process.py @@ -1,16 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import Callable, Union +from collections.abc import Callable, Sequence +from typing import TypeAlias import torch from vllm.transformers_utils.tokenizer import AnyTokenizer -LogitsProcessor = Union[ - Callable[[list[int], torch.Tensor], torch.Tensor], - Callable[[list[int], list[int], torch.Tensor], torch.Tensor], -] +LogitsProcessor: TypeAlias = ( + Callable[[list[int], torch.Tensor], torch.Tensor] + | Callable[[list[int], list[int], torch.Tensor], torch.Tensor] +) """LogitsProcessor is a function that takes a list of previously generated tokens, the logits tensor for the next token and, optionally, prompt tokens as a diff --git a/vllm/logprobs.py b/vllm/logprobs.py index 2458e43c690f..21c886e0ad5e 100644 --- a/vllm/logprobs.py +++ b/vllm/logprobs.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional # We use dataclass for now because it is used for @@ -18,12 +17,12 @@ class Logprob: """ logprob: float - rank: Optional[int] = None - decoded_token: Optional[str] = None + 
rank: int | None = None + decoded_token: str | None = None # {token_id -> logprob} per each sequence group. None if the corresponding # sequence group doesn't require prompt logprob. -PromptLogprobs = list[Optional[dict[int, Logprob]]] +PromptLogprobs = list[dict[int, Logprob] | None] # {token_id -> logprob} for each sequence group. SampleLogprobs = list[dict[int, Logprob]] diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py index 4915ef85f4f7..8a4f5ff175d4 100644 --- a/vllm/lora/layers/__init__.py +++ b/vllm/lora/layers/__init__.py @@ -11,6 +11,7 @@ QKVParallelLinearWithLoRA, QKVParallelLinearWithShardedLoRA, ) +from vllm.lora.layers.fused_moe import FusedMoEWithLoRA from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA from vllm.lora.layers.row_parallel_linear import ( @@ -36,4 +37,5 @@ "RowParallelLinearWithShardedLoRA", "ReplicatedLinearWithLoRA", "LoRAMapping", + "FusedMoEWithLoRA", ] diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py index 753dc268a2ff..0c7e80684889 100644 --- a/vllm/lora/layers/base.py +++ b/vllm/lora/layers/base.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import torch import torch.nn as nn @@ -15,14 +15,14 @@ class BaseLayerWithLoRA(nn.Module): def slice_lora_a( - self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + self, lora_a: torch.Tensor | list[torch.Tensor | None] + ) -> torch.Tensor | list[torch.Tensor | None]: """Slice lora a if splitting for tensor parallelism.""" ... def slice_lora_b( - self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + self, lora_b: torch.Tensor | list[torch.Tensor | None] + ) -> torch.Tensor | list[torch.Tensor | None]: """Slice lora b if splitting with tensor parallelism.""" ... @@ -30,7 +30,7 @@ def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, + model_config: PretrainedConfig | None = None, ) -> None: """Initializes lora matrices.""" ... @@ -44,8 +44,7 @@ def set_lora( index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + embeddings_tensor: torch.Tensor | None, ): """Overwrites lora tensors at index.""" ... 
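The `scope` parameter added to the `*_once` helpers in `vllm/logger.py` above accepts `"process"` (the default), `"global"`, or `"local"`. A usage sketch; the messages are invented for illustration, and `init_logger` is vLLM's existing logger factory:

```python
from vllm.logger import init_logger

logger = init_logger(__name__)

# Default "process" scope: deduplicated per process, so every rank may
# still emit the message once.
logger.info_once("Using tuned multi-LoRA kernel configurations")

# "global" scope: only the first rank of the entire run logs it.
logger.warning_once("Falling back to default kernel configurations", scope="global")

# "local" scope: one message per node (first local rank only).
logger.debug_once("LoRA adapters initialized", scope="local")
```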
@@ -62,7 +61,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" raise NotImplementedError diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index d2f017c19ccd..d619a0edc124 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, cast import torch from transformers import PretrainedConfig @@ -29,7 +28,6 @@ def __init__(self, base_layer: LinearBase): self.tp_size = self.base_layer.tp_size self.tp_rank = self.base_layer.tp_rank self.device = _get_lora_device(self.base_layer) - self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None self.output_slices: tuple[int, ...] self.output_size: int self.n_slices: int @@ -38,7 +36,7 @@ def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, + model_config: PretrainedConfig | None = None, ) -> None: self.lora_config = lora_config # @@ -86,38 +84,19 @@ def create_lora_weights( ) for _ in range(self.n_slices) ) - if lora_config.bias_enabled: - lora_bias_out_size = lora_b_out_size - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - lora_bias_out_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - for _ in range(self.n_slices) - ) self.output_slices = (self.lora_b_stacked[0].shape[2],) def reset_lora(self, index: int): for s_index in range(self.n_slices): self.lora_a_stacked[s_index][index] = 0 self.lora_b_stacked[s_index][index] = 0 - if self.lora_config.bias_enabled: - # Make mypy happy - self.lora_bias_stacked = cast( - tuple[torch.Tensor, ...], self.lora_bias_stacked - ) - self.lora_bias_stacked[s_index][index] = 0 def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, + embeddings_tensor: torch.Tensor | None, ): # Except for QKVParallelLinearWithLoRA and # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers @@ -131,8 +110,6 @@ def set_lora( if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( lora_a, non_blocking=True @@ -140,18 +117,8 @@ def set_lora( self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( lora_b, non_blocking=True ) - if lora_bias is not None: - self.lora_bias_stacked = cast( - tuple[torch.Tensor, ...], self.lora_bias_stacked - ) - assert len(self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, : lora_bias.shape[0]].copy_( - lora_bias, non_blocking=True - ) - def apply( - self, x: torch.Tensor, bias: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) # In transformers backend, x and output have extra batch dimension like @@ -161,14 +128,8 @@ def apply( output = output.flatten(0, 1) x = x.flatten(0, 1) - lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear( - output, - x, - self.lora_a_stacked, - 
self.lora_b_stacked, - self.lora_bias_stacked, - 1.0, - self.output_slices, + lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices ) if not current_platform.can_update_inplace(): output = lora_output @@ -196,7 +157,7 @@ def weight(self) -> torch.Tensor: raise ValueError(f"Unsupported base layer: {self.base_layer}") @property - def bias(self) -> Optional[torch.Tensor]: + def bias(self) -> torch.Tensor | None: if hasattr(self.base_layer, "bias"): return self.base_layer.bias else: diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 011d38157456..637ded9b2a0f 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union, cast import torch import torch.nn as nn @@ -32,8 +31,6 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): == len(layer.lora_b_stacked) == len(layer.output_slices) ) - if layer.lora_bias_stacked is not None: - assert layer.n_slices == len(layer.lora_bias_stacked) output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) @@ -48,7 +45,7 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): device=x.device, ) - shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + shrunk_buffers: torch.Tensor | None = layer.punica_wrapper.add_shrink( buffers, x, layer.lora_a_stacked, 1.0 ) @@ -57,11 +54,10 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): buffers = tensor_model_parallel_all_gather(buffers) - lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + lora_output: torch.Tensor | None = layer.punica_wrapper.add_expand( output, buffers, layer.lora_b_stacked, - layer.lora_bias_stacked, layer.output_slices, offset_start=0, add_input=True, @@ -122,19 +118,9 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: lora_b = lora_b[start_idx:end_idx, :] return lora_b - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - # TODO: Fix the slicing logic of bias. 
- if bias is None: - return bias - shard_size = self.output_size - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - def forward( self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]: """Forward of ColumnParallelLinear Args: @@ -167,7 +153,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is ColumnParallelLinear or ( type(source_layer) is MergedColumnParallelLinear @@ -185,7 +171,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ def __init__( - self, base_layer: Union[MergedColumnParallelLinear, QKVParallelLinear] + self, base_layer: MergedColumnParallelLinear | QKVParallelLinear ) -> None: super().__init__(base_layer) # There are two LoRA layers @@ -202,7 +188,7 @@ def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, + model_config: PretrainedConfig | None = None, ) -> None: """ The main reason for overriding this function is to enhance code @@ -238,26 +224,15 @@ def create_lora_weights( ) for output_size in self.output_slices ) - if lora_config.bias_enabled: - self.lora_bias_stacked = tuple( - torch.zeros( - max_loras, - 1, - output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - for output_size in self.output_slices - ) def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: + self, lora_a: list[torch.Tensor | None] + ) -> list[torch.Tensor | None]: return lora_a def slice_lora_b( - self, lora_b: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: + self, lora_b: list[torch.Tensor | None] + ) -> list[torch.Tensor | None]: sliced_lora_b = [None] * self.n_slices for i, (shard_id, shard_size) in enumerate( zip(self.output_ids, self.output_slices) @@ -268,31 +243,18 @@ def slice_lora_b( ] return sliced_lora_b - def slice_bias( - self, bias: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: - for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices) - ): - if (bias_i := bias[i]) is not None: - bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)] - return bias - def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, + embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) for i in range(self.n_slices): if (lora_a_i := lora_a[i]) is not None: @@ -304,16 +266,6 @@ def set_lora( index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1] ].copy_(lora_b_i, non_blocking=True) - if lora_bias is not None: - self.lora_bias_stacked = cast( - tuple[torch.Tensor, ...], self.lora_bias_stacked - ) - for i in range(self.n_slices): - if (lora_bias_i := lora_bias[i]) is not None: - self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_( - lora_bias_i, non_blocking=True - ) - @classmethod @_not_fully_sharded_can_replace def can_replace_layer( @@ -321,7 +273,7 @@ def can_replace_layer( source_layer: nn.Module, 
lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return ( type(source_layer) is MergedColumnParallelLinear @@ -380,24 +332,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0) return lora_b - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - bias_q = bias[ - self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size - * (self.q_shard_id + 1) - ] - k_offset = self.q_proj_total_size - bias_k = bias[ - k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset - + self.kv_proj_shard_size * (self.kv_shard_id + 1) - ] - v_offset = k_offset + self.kv_proj_total_size - bias_v = bias[ - v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset - + self.kv_proj_shard_size * (self.kv_shard_id + 1) - ] - bias = torch.cat([bias_q, bias_k, bias_v], dim=1) - return bias - @classmethod @_not_fully_sharded_can_replace def can_replace_layer( @@ -405,7 +339,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1 @@ -448,7 +382,7 @@ def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, + model_config: PretrainedConfig | None = None, ) -> None: """ The main reason for overloading this function is to handle inconsistent @@ -463,7 +397,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3 @@ -491,9 +425,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: lora_a = lora_a[start_idx : start_idx + shard_size, :] return lora_a - def apply( - self, x: torch.Tensor, bias: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -503,7 +435,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( @@ -524,8 +456,8 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo """ def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: + self, lora_a: list[torch.Tensor | None] + ) -> list[torch.Tensor | None]: # NOTE: lora_a contains 2 subloras, and each sublora could be None. 
output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size @@ -539,9 +471,7 @@ def slice_lora_a( ] return lora_a - def apply( - self, x: torch.Tensor, bias: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -551,7 +481,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( @@ -577,9 +507,7 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: lora_a = lora_a[start_idx : start_idx + shard_size, :] return lora_a - def apply( - self, x: torch.Tensor, bias: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -589,7 +517,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( @@ -610,8 +538,8 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): """ def slice_lora_a( - self, lora_a: list[Union[torch.Tensor, None]] - ) -> list[Union[torch.Tensor, None]]: + self, lora_a: list[torch.Tensor | None] + ) -> list[torch.Tensor | None]: # NOTE: lora_a contains 3 subloras, and each sublora could be None. shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)] @@ -628,9 +556,7 @@ def slice_lora_a( ] return lora_a - def apply( - self, x: torch.Tensor, bias: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -640,7 +566,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py new file mode 100644 index 000000000000..5a9fd35c2907 --- /dev/null +++ b/vllm/lora/layers/fused_moe.py @@ -0,0 +1,411 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm import envs +from vllm.config.lora import LoRAConfig +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import ( + FUSED_MOE_UNQUANTIZED_CONFIG, + _get_config_dtype_str, + mxfp4_w4a16_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + modular_marlin_fused_moe, +) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + modular_triton_fused_moe, + try_get_optimal_moe_config, +) +from 
vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config + + +class FusedMoEWithLoRA(BaseLayerWithLoRA): + def __init__(self, base_layer: FusedMoE) -> None: + super().__init__() + self.base_layer = base_layer + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.device = base_layer.w2_weight.device + self._inject_lora_into_fused_moe() + + def _inject_lora_into_fused_moe(self): + moe_state_dict = {} + top_k = self.base_layer.top_k + + if self.base_layer.quant_config is None: + quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + elif not isinstance(self.base_layer.quant_config, Mxfp4Config): + quant_config = self.base_layer.quant_config + else: + quant_config = mxfp4_w4a16_moe_quant_config( + w1_bias=self.base_layer.w13_bias, + w2_bias=self.base_layer.w2_bias, + w1_scale=self.base_layer.w13_weight_scale, + w2_scale=self.base_layer.w2_weight_scale, + ) + + m_fused_moe_fn = ( + modular_triton_fused_moe( + quant_config, shared_experts=self.base_layer.shared_experts + ) + if not quant_config.use_mxfp4_w4a16 + else modular_marlin_fused_moe( + quant_config, shared_experts=self.base_layer.shared_experts + ) + ) + + def fwd_decorator(layer, func): + def wrapper(*args, **kwargs): + moe_state_dict["hidden_states"] = kwargs["hidden_states"] + moe_state_dict["topk_ids"] = kwargs["topk_ids"] + moe_state_dict["topk_weights"] = kwargs["topk_weights"] + moe_state_dict["global_num_experts"] = kwargs["global_num_experts"] + moe_state_dict["expert_map"] = kwargs["expert_map"] + moe_state_dict["apply_router_weight_on_input"] = kwargs[ + "apply_router_weight_on_input" + ] + result = func(*args, **kwargs) + return result + + return wrapper + + def act_decorator(layer, func): + def wrapper(*args, **kwargs): + _, output, input = args + + hidden_states = moe_state_dict["hidden_states"] + topk_weights = moe_state_dict["topk_weights"] + curr_topk_ids = moe_state_dict["topk_ids"] + global_num_experts = moe_state_dict["global_num_experts"] + expert_map = moe_state_dict["expert_map"] + + config_dtype = _get_config_dtype_str( + dtype=hidden_states.dtype, + use_fp8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + ) + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_tokens = hidden_states.size(0) + M = min(num_tokens, CHUNK_SIZE) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + layer.w13_weight.size(), + layer.w2_weight.size(), + top_k, + config_dtype, + block_shape=layer.quant_method.moe_quant_config.block_shape, + ) + + max_loras = self.w1_lora_a_stacked.shape[0] + config = get_config_func(M) + ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + ) = self.punica_wrapper.moe_lora_align_block_size( + curr_topk_ids, + num_tokens, + config["BLOCK_SIZE_M"], + global_num_experts, + max_loras, + expert_map, + ) + + moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora + moe_state_dict["expert_ids_lora"] = expert_ids_lora + moe_state_dict["num_tokens_post_padded_lora"] = ( + num_tokens_post_padded_lora + ) + + w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked] + w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked] + max_lora_rank = self.w1_lora_a_stacked.shape[-2] + expert_ids_lora = expert_ids_lora.view(max_loras, -1) + sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) + + self.punica_wrapper.add_lora_fused_moe( + input.view(-1, top_k, input.shape[-1]), + hidden_states, + w13_lora_a_stacked, + w13_lora_b_stacked, + topk_weights, + 
sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + max_lora_rank, + top_k, + config, + ) + + result = func(*args, **kwargs) + + moe_state_dict["intermediate_cache2"] = output + return result + + return wrapper + + def moe_sum_decorator(layer, func): + def wrapper(*args, **kwargs): + hidden_states = moe_state_dict["hidden_states"] + topk_weights = moe_state_dict["topk_weights"] + + config_dtype = _get_config_dtype_str( + dtype=hidden_states.dtype, + use_fp8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + ) + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_tokens = hidden_states.size(0) + M = min(num_tokens, CHUNK_SIZE) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + layer.w13_weight.size(), + layer.w2_weight.size(), + top_k, + config_dtype, + block_shape=layer.quant_method.moe_quant_config.block_shape, + ) + + config = get_config_func(M) + + sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"] + expert_ids_lora = moe_state_dict["expert_ids_lora"] + num_tokens_post_padded_lora = moe_state_dict[ + "num_tokens_post_padded_lora" + ] + max_loras = self.w1_lora_a_stacked.shape[0] + expert_ids_lora = expert_ids_lora.view(max_loras, -1) + sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1) + intermediate_cache2 = moe_state_dict["intermediate_cache2"] + intermediate_cache3 = args[0] + max_lora_rank = self.w1_lora_a_stacked.shape[-2] + self.punica_wrapper.add_lora_fused_moe( + intermediate_cache3, + intermediate_cache2, + [self.w2_lora_a_stacked], + [self.w2_lora_b_stacked], + topk_weights, + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + max_lora_rank, + top_k, + config, + True, + ) + + result = func(*args, **kwargs) + return result + + return wrapper + + fused_experts = m_fused_moe_fn.fused_experts + + m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward) + fused_experts.activation = act_decorator( + self.base_layer, fused_experts.activation + ) + fused_experts.moe_sum = moe_sum_decorator( + self.base_layer, fused_experts.moe_sum + ) + + self.base_layer.quant_method.old_fused_experts = ( + self.base_layer.quant_method.fused_experts + ) + self.base_layer.quant_method.fused_experts = m_fused_moe_fn + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: PretrainedConfig | None = None, + ) -> None: + """Initializes lora matrices.""" + + assert not self.base_layer.use_ep, ( + "EP support for Fused MoE LoRA is not implemented yet." 
+ ) + + self.w1_lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + lora_config.max_lora_rank, + self.base_layer.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.w1_lora_b_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + self.base_layer.intermediate_size_per_partition, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + + self.w2_lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + lora_config.max_lora_rank, + self.base_layer.intermediate_size_per_partition, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.w2_lora_b_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + self.base_layer.hidden_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + + self.w3_lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + lora_config.max_lora_rank, + self.base_layer.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.w3_lora_b_stacked = torch.zeros( + ( + max_loras, + self.base_layer.global_num_experts, + self.base_layer.intermediate_size_per_partition, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + + # They will be used by 'LoRALayerWeights.create_dummy_lora_weights' + # to create a dummy LoRA weights. + self.lora_a_stacked = [] + self.lora_b_stacked = [] + for lora_id in range(max_loras): + for experts_id in range(self.base_layer.global_num_experts): + # gate_proj,down_proj,up_proj + self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id]) + self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id]) + self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id]) + + self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id]) + self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id]) + self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id]) + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + self.w1_lora_a_stacked[index] = 0 + self.w1_lora_b_stacked[index] = 0 + self.w3_lora_a_stacked[index] = 0 + self.w3_lora_b_stacked[index] = 0 + self.w2_lora_a_stacked[index] = 0 + self.w2_lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: torch.Tensor | None, + bias: torch.Tensor | None = None, + ): + self.reset_lora(index) + """Overwrites lora tensors at index.""" + for eid in range(len(lora_a) // 3): + w1_lora_a = lora_a[eid * 3] + w2_lora_a = lora_a[eid * 3 + 1] + w3_lora_a = lora_a[eid * 3 + 2] + w1_lora_b = lora_b[eid * 3] + w2_lora_b = lora_b[eid * 3 + 1] + w3_lora_b = lora_b[eid * 3 + 2] + + # Handle the case of adding LoRA to only a subset of experts + if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None: + continue + + if self.tp_size > 1: + shard_size = self.base_layer.intermediate_size_per_partition + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + + w1_lora_b = w1_lora_b[start_idx:end_idx, :] + w3_lora_b = w3_lora_b[start_idx:end_idx, :] + w2_lora_a = w2_lora_a[:, start_idx:end_idx] + + self.w1_lora_a_stacked[ + index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1] + ].copy_(w1_lora_a, non_blocking=True) + + self.w3_lora_a_stacked[ + index, eid, : w3_lora_a.shape[0], : 
w3_lora_a.shape[1] + ].copy_(w3_lora_a, non_blocking=True) + + self.w2_lora_b_stacked[ + index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1] + ].copy_(w2_lora_b, non_blocking=True) + + self.w1_lora_b_stacked[ + index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1] + ].copy_(w1_lora_b, non_blocking=True) + self.w3_lora_b_stacked[ + index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1] + ].copy_(w3_lora_b, non_blocking=True) + self.w2_lora_a_stacked[ + index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1] + ].copy_(w2_lora_a, non_blocking=True) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None, + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + # return type(source_layer) is FusedMoE + return isinstance(source_layer, FusedMoE) + + def forward(self, *args, **kwargs): + return self.base_layer.forward(*args, **kwargs) + + def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs): + return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs) + + @property + def _shared_experts(self): + return self.base_layer._shared_experts + + @property + def quant_method(self): + return self.base_layer.quant_method + + @property + def is_internal_router(self) -> bool: + return self.base_layer.is_internal_router diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index 4f30c9db4c67..adc5e861f57f 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import Optional import torch import torch.nn as nn @@ -41,7 +40,7 @@ def __init__( hidden_size: int, dtype: torch.dtype, device: torch.device, - sharded_to_full_mapping: Optional[list[int]], + sharded_to_full_mapping: list[int] | None, ) -> None: super().__init__() self.base_layer = base_layer @@ -88,7 +87,7 @@ def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, + model_config: PretrainedConfig | None = None, ) -> None: # TODO: Verify if this condition can be further relaxed if 32000 < self.base_layer.vocab_size > 257024: @@ -142,8 +141,7 @@ def set_lora( index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( @@ -163,8 +161,8 @@ def _get_logits( self, hidden_states: torch.Tensor, lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: + embedding_bias: torch.Tensor | None = None, + ) -> torch.Tensor | None: # Get the logits for the next tokens. 
logits = lm_head.quant_method.apply(lm_head, hidden_states) if embedding_bias is not None: @@ -228,7 +226,7 @@ def _get_logits( + lora_logits.shape[1], ] = lora_logits - lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_logits( + lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits( logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0 ) @@ -248,7 +246,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: # Special handling for the LogitsProcessor. return False diff --git a/vllm/lora/layers/qkv_x_parallel_linear.py b/vllm/lora/layers/qkv_x_parallel_linear.py deleted file mode 100644 index 785cdf38e360..000000000000 --- a/vllm/lora/layers/qkv_x_parallel_linear.py +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .base import BaseLayerWithLoRA - - -# TODO: Implement this -class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): - pass diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py index 18a35cd1e0f2..243736c4ebc6 100644 --- a/vllm/lora/layers/replicated_linear.py +++ b/vllm/lora/layers/replicated_linear.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch import torch.nn as nn @@ -24,7 +23,7 @@ def __init__(self, base_layer: ReplicatedLinear) -> None: def forward( self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]: """Forward of ReplicatedLinearWithLoRA Args: @@ -54,6 +53,18 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is ReplicatedLinear + + def slice_lora_a( + self, lora_a: torch.Tensor | list[torch.Tensor | None] + ) -> torch.Tensor | list[torch.Tensor | None]: + """Slice lora a if splitting for tensor parallelism.""" + return lora_a + + def slice_lora_b( + self, lora_b: torch.Tensor | list[torch.Tensor | None] + ) -> torch.Tensor | list[torch.Tensor | None]: + """Slice lora b if splitting with tensor parallelism.""" + return lora_b diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index 738371f22a36..2ef1bd98fc61 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union, cast import torch import torch.nn as nn @@ -39,12 +38,9 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: return lora_b - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias - def forward( self, input_: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]: """Forward of RowParallelLinear Args: @@ -96,7 +92,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + 
model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is RowParallelLinear @@ -123,19 +119,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: lora_b = lora_b[start_idx:end_idx, :] return lora_b - def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - if bias is None: - return bias - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) - shard_size = self.lora_bias_stacked[0].shape[2] - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = bias[start_idx:end_idx] - return bias - - def apply( - self, x: torch.Tensor, bias: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) x = x.view(-1, x.shape[-1]) @@ -146,7 +130,7 @@ def apply( device=x.device, ) - shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink( buffer, x, self.lora_a_stacked, 1.0 ) if not current_platform.can_update_inplace(): @@ -163,11 +147,10 @@ def apply( # NOTE offset are based on the rank. shard_size = self.lora_b_stacked[0].shape[2] offset_start = self.tp_rank * shard_size - lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( + lora_output: torch.Tensor | None = self.punica_wrapper.add_expand( output, buffer, self.lora_b_stacked, - self.lora_bias_stacked, self.output_slices, offset_start=offset_start, add_input=True, @@ -186,7 +169,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index 42eae1d4e3b0..ca4ad8012e9c 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch import torch.nn as nn @@ -19,14 +18,14 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer - self.embeddings_slice: Optional[tuple[int, int]] - self.embeddings_weights: Optional[torch.Tensor] + self.embeddings_slice: tuple[int, int] | None + self.embeddings_weights: torch.Tensor | None def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, + model_config: PretrainedConfig | None = None, ) -> None: if self.base_layer.num_added_embeddings_per_partition > 0: # We can start adding lora weights @@ -90,8 +89,7 @@ def set_lora( index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) # NOTE self.lora_a_stacked is row-major, and lora_a is col-major, @@ -144,7 +142,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: -1, ) - lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_embedding( + lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_embedding( full_output, full_lora_a_embeddings, 
self.lora_b_stacked, add_input=True ) @@ -159,7 +157,7 @@ def can_replace_layer( source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is VocabParallelEmbedding diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index d502c8eb543f..7691481d5039 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -8,7 +8,7 @@ import torch.types from vllm.lora.peft_helper import PEFTHelper -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available class LoRALayerWeights: @@ -21,16 +21,14 @@ def __init__( lora_alpha: int, lora_a: torch.Tensor, lora_b: torch.Tensor, - bias: Optional[torch.Tensor] = None, - embeddings_tensor: Optional[torch.Tensor] = None, - scaling: Optional[float] = None, + embeddings_tensor: torch.Tensor | None = None, + scaling: float | None = None, ) -> None: self.module_name = module_name self.rank = rank self.lora_alpha = lora_alpha self.lora_a = lora_a self.lora_b = lora_b - self.bias = bias self.embeddings_tensor = embeddings_tensor if scaling is None: @@ -69,15 +67,15 @@ def from_config( cls, module_name: str, peft_helper: PEFTHelper, - embeddings_tensor: Optional[torch.Tensor] = None, + embeddings_tensor: torch.Tensor | None = None, ) -> "LoRALayerWeights": + # lora_a and lora_b are set to None for config-based construction return cls( module_name, peft_helper.r, peft_helper.lora_alpha, None, None, - None, embeddings_tensor, peft_helper.vllm_lora_scaling_factor, ) @@ -91,8 +89,7 @@ def create_dummy_lora_weights( rank: int, dtype: torch.dtype, device: torch.types.Device, - embeddings_tensor_dim: Optional[int] = None, - bias_enabled: Optional[bool] = False, + embeddings_tensor_dim: int | None = None, ) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() lora_a = torch.zeros( @@ -101,12 +98,6 @@ def create_dummy_lora_weights( lora_b = torch.zeros( [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory ) - if bias_enabled: - bias = torch.zeros( - [output_dim], dtype=dtype, device=device, pin_memory=pin_memory - ) - else: - bias = None embeddings_tensor = ( torch.rand( @@ -125,7 +116,6 @@ def create_dummy_lora_weights( lora_alpha=1, lora_a=lora_a, lora_b=lora_b, - bias=bias, embeddings_tensor=embeddings_tensor, ) @@ -137,11 +127,10 @@ def __init__( self, module_name: str, rank: int, - lora_alphas: list[Optional[int]], - lora_a: list[Optional[torch.Tensor]], - lora_b: list[Optional[torch.Tensor]], - bias: Optional[list[Optional[torch.Tensor]]] = None, - scaling: Optional[list[float]] = None, + lora_alphas: list[int | None], + lora_a: list[torch.Tensor | None], + lora_b: list[torch.Tensor | None], + scaling: list[float] | None = None, ) -> None: super().__init__( module_name=module_name, @@ -149,7 +138,6 @@ def __init__( lora_alpha=0, lora_a=lora_a, lora_b=lora_b, - bias=bias, scaling=scaling, # type: ignore embeddings_tensor=None, ) @@ -181,7 +169,6 @@ def pack( [lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras], - [lora.bias if lora is not None else None for lora in loras], scaling=[ 1 if lora is not None else None # type: ignore for lora in loras diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 771c8608f4a8..02c252f15bfa 100644 --- a/vllm/lora/models.py +++ 
b/vllm/lora/models.py @@ -3,8 +3,8 @@ import math import os -from collections.abc import Sequence -from typing import Callable, Optional, TypeVar, Union +from collections.abc import Callable +from typing import TypeVar import regex as re import safetensors.torch @@ -13,7 +13,7 @@ from vllm.config.lora import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping +from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper @@ -23,17 +23,16 @@ get_supported_lora_modules, is_regex_target_modules, parse_fine_tuned_lora_name, + process_packed_modules_mapping, replace_submodule, ) -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper -from vllm.model_executor.utils import get_packed_modules_mapping -from vllm.utils import is_pin_memory_available from vllm.utils.cache import LRUCache +from vllm.utils.platform_utils import is_pin_memory_available logger = init_logger(__name__) @@ -45,7 +44,7 @@ def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): super().__init__(capacity) self.deactivate_fn = deactivate_fn - def _on_remove(self, key: int, value: Optional[T]): + def _on_remove(self, key: int, value: T | None): logger.debug("Removing adapter int id: %d", key) self.deactivate_fn(key) return super()._on_remove(key, value) @@ -60,18 +59,6 @@ def get_lora_id(): return _GLOBAL_LORA_ID -def is_moe_model(model: nn.Module) -> bool: - """Checks if the model contains FusedMoE layers and warns the user.""" - if any(isinstance(module, FusedMoE) for module in model.modules()): - logger.warning_once( - "For MoE models, vLLM currently does not support fused MoE LoRA " - "inference. Please ensure that the loaded LoRA model does not " - "contain expert weights." 
- ) - return True - return False - - class LoRAModel: """A LoRA fine-tuned model.""" @@ -114,7 +101,7 @@ def extra_vocab_size(self) -> int: else 0 ) - def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + def get_lora(self, module_name: str) -> LoRALayerWeights | None: """Get LoRA for a given module by name""" return self.loras.get(module_name, None) @@ -129,18 +116,18 @@ def from_lora_tensors( tensors: dict[str, torch.Tensor], peft_helper: PEFTHelper, device: str = "cuda", - dtype: Optional[torch.dtype] = None, - embeddings: Optional[dict[str, torch.Tensor]] = None, - target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[dict[str, str]] = None, - embedding_padding_modules: Optional[list[str]] = None, - weights_mapper: Optional[WeightsMapper] = None, + dtype: torch.dtype | None = None, + embeddings: dict[str, torch.Tensor] | None = None, + target_embedding_padding: int | None = None, + embedding_modules: dict[str, str] | None = None, + embedding_padding_modules: list[str] | None = None, + weights_mapper: WeightsMapper | None = None, ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" pin_memory = str(device) == "cpu" and is_pin_memory_available() loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): - module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( + module_name, is_lora_a = parse_fine_tuned_lora_name( tensor_name, weights_mapper ) if module_name not in loras: @@ -160,13 +147,7 @@ def from_lora_tensors( module_name, peft_helper, lora_embeddings_tensor ) - if is_bias: - loras[module_name].bias = tensor.to(device=device, dtype=dtype) - bias = tensor.to(device=device, dtype=dtype) - if pin_memory: - bias = bias.pin_memory() - loras[module_name].bias = bias - elif is_lora_a: + if is_lora_a: loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) if pin_memory: loras[module_name].lora_a = loras[module_name].lora_a.pin_memory() @@ -198,14 +179,14 @@ def from_local_checkpoint( expected_lora_modules: list[str], peft_helper: PEFTHelper, *, - lora_model_id: Optional[int] = None, + lora_model_id: int | None = None, device: str = "cuda", - dtype: Optional[torch.dtype] = None, - target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[dict[str, str]] = None, - embedding_padding_modules: Optional[list[str]] = None, - weights_mapper: Optional[WeightsMapper] = None, - tensorizer_config_dict: Optional[dict] = None, + dtype: torch.dtype | None = None, + target_embedding_padding: int | None = None, + embedding_modules: dict[str, str] | None = None, + embedding_padding_modules: list[str] | None = None, + weights_mapper: WeightsMapper | None = None, + tensorizer_config_dict: dict | None = None, ) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. 
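The hunk that follows relaxes the module-name validation in `from_local_checkpoint` so that per-expert LoRA tensors (and the FSDP-style `base_layer` entries) are accepted. A standalone sketch of the matching rule it implements; `expected_lora_modules` and the example names are assumptions used only for illustration:

```python
expected_lora_modules = ["gate_up_proj", "down_proj", "w1", "w2", "w3"]


def is_unexpected(module_name: str) -> bool:
    # Expert weights: accept any module whose name ends with a known
    # projection, e.g. "model.layers.0.mlp.experts.w1".
    if ".experts" in module_name:
        return not any(module_name.endswith(m) for m in expected_lora_modules)
    # Everything else: only the last path component must be known.
    return module_name.split(".")[-1] not in expected_lora_modules


assert not is_unexpected("model.layers.0.mlp.experts.w1")
assert not is_unexpected("model.layers.0.mlp.down_proj")
assert is_unexpected("model.layers.0.mlp.experts.router")
```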
@@ -230,16 +211,24 @@ def from_local_checkpoint( ) new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") tensors: dict[str, torch.Tensor] = {} - unexpected_modules: list[Union[list[str], str]] = [] + unexpected_modules: list[list[str] | str] = [] def check_unexpected_modules(modules: dict): for lora_module in modules.keys(): # noqa - module_name, _, _ = parse_fine_tuned_lora_name( - lora_module, weights_mapper - ) - part_name = module_name.split(".")[-1] - if part_name not in expected_lora_modules: + module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) + # Handle FSDP file format where experts.base_layer is the + # gate_up_proj and experts is the down_proj + if "base_layer" in lora_module: + continue + # Case for expert lora weights + if ".experts" in module_name: + if not any( + module_name.endswith(ele) for ele in expected_lora_modules + ): + unexpected_modules.append(module_name) + elif module_name.split(".")[-1] not in expected_lora_modules: unexpected_modules.append(module_name) + if unexpected_modules: raise ValueError( f"While loading {lora_dir}, expected" @@ -366,7 +355,7 @@ def __init__( self.max_num_seqs = max_num_seqs assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 - self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots + self.lora_index_to_id: list[int | None] = [None] * self.lora_slots self.vocab_size = vocab_size self.punica_wrapper = get_punica_wrapper( max_num_batched_tokens, @@ -379,7 +368,7 @@ def __init__( assert self.supported_lora_modules, "No supported LoRA modules found in" f" {self.model.__class__.__name__}." - self.packed_modules_mapping = get_packed_modules_mapping(self.model) + self.packed_modules_mapping = process_packed_modules_mapping(self.model) # Used to indicate whether the model is a multimodal model self.supports_mm: bool = ( supports_multimodal(self.model) @@ -388,11 +377,10 @@ def __init__( and hasattr(self.model, "get_mm_mapping") ) self.is_pooling_model = is_pooling_model(self.model) - self.is_moe_model = is_moe_model(self.model) self.packed_modules: dict[str, list[str]] = {} self.modules: dict[str, BaseLayerWithLoRA] = {} # Dict instead of a set for compatibility with LRUCache. - self._last_mapping: Optional[LoRAMapping] = None + self._last_mapping: LoRAMapping | None = None self._create_lora_modules() self.model.lora_manager = self @@ -438,24 +426,55 @@ def activate_adapter( for module_name, module in self.modules.items(): module_lora = self._get_lora_layer_weights(lora_model, module_name) if module_lora: - module_lora.optimize() - # Bias is not explicitly enabled with the flag enable_lora_bias. - bias = module_lora.bias - if ( - torch.is_tensor(bias) - or (isinstance(bias, Sequence) and any(b is not None for b in bias)) - ) and not self.lora_config.bias_enabled: - module_lora.bias = None - raise ValueError( - f"Adapter bias cannot be used for {module_name}" - " without --enable-lora-bias." 
+ # Note (gnovack) - If MOE lora weights are not split into + # num_experts chunks, we split them here + if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor( + module_lora.lora_a + ): + # Handle FSDP file format where experts.base_layer is the + # gate_up_proj and experts is the down_proj + gate_up_proj_lora = self._get_lora_layer_weights( + lora_model, module_name + ".base_layer" + ) + + assert gate_up_proj_lora is not None + assert module_lora is not None + + down_proj_lora = module_lora + num_experts = module_lora.lora_a.shape[0] // module_lora.rank + + gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) + up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0) + + gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk( + num_experts, dim=-1 + ) + up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk( + num_experts, dim=-1 ) + + down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0) + down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1) + + lora_a = [] + lora_b = [] + for i in range(num_experts): + lora_a.append(gate_proj_a[i]) + lora_a.append(down_proj_a[i]) + lora_a.append(up_proj_a[i]) + + lora_b.append(gate_proj_b[i]) + lora_b.append(down_proj_b[i]) + lora_b.append(up_proj_b[i]) + + module_lora.lora_a = lora_a + module_lora.lora_b = lora_b + module.set_lora( index, module_lora.lora_a, module_lora.lora_b, module_lora.embeddings_tensor, - module_lora.bias, ) else: module.reset_lora(index) @@ -506,6 +525,7 @@ def _parent_module(module_name: str) -> str: for module_name, module in self.model.named_modules(remove_duplicate=False): if isinstance(module, PPMissingLayer): continue + if not self._match_target_modules(module_name): continue # A temporary approach for multimodal models to support LoRA @@ -569,19 +589,21 @@ def _parent_module(module_name: str) -> str: new_module.set_mapping(self.punica_wrapper) def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): - assert isinstance(module, BaseLayerWithLoRA) + assert isinstance(module, BaseLayerWithLoRA), ( + f"Module {module_name} must be a BaseLayerWithLoRA instance," + ) + f" got {type(module)}" self.modules[module_name] = module def create_dummy_lora( self, lora_id: int, rank: int, - embedding_modules: Optional[dict[str, str]] = None, + embedding_modules: dict[str, str] | None = None, ) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}) for module_name, module in self.model.named_modules(): - bias_enabled = self.lora_config.bias_enabled if ( not self._match_target_modules(module_name) or not isinstance(module, BaseLayerWithLoRA) @@ -616,7 +638,6 @@ def create_dummy_lora( module.lora_a_stacked[0].dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim, - bias_enabled=bias_enabled, ) else: lora = LoRALayerWeights.create_dummy_lora_weights( @@ -626,12 +647,11 @@ def create_dummy_lora( rank, module.lora_a_stacked[0].dtype, "cpu", - bias_enabled=bias_enabled, ) else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] - subloras: list[Optional[LoRALayerWeights]] = [] + subloras: list[LoRALayerWeights | None] = [] for i, r in enumerate(replacements): lora = LoRALayerWeights.create_dummy_lora_weights( module_name + "." 
+ r, @@ -640,7 +660,6 @@ def create_dummy_lora( rank, module.lora_a_stacked[i].dtype, "cpu", - bias_enabled=bias_enabled, ) subloras.append(lora) lora = PackedLoRALayerWeights.pack(subloras) @@ -683,7 +702,7 @@ def _register_packed_modules(self, module_full_name: str) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): - replacement_loras: list[Optional[LoRALayerWeights]] = [] + replacement_loras: list[LoRALayerWeights | None] = [] replaced_module: set[str] = set() has_replacement = False for r in new_module_names: @@ -712,7 +731,7 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: def _get_lora_layer_weights( self, lora_model: LoRAModel, module_name: str - ) -> Optional[LoRALayerWeights]: + ) -> LoRALayerWeights | None: org_module_name = module_name if self.is_pooling_model and not lora_model.check_lora_name(module_name): # If it's a pool model, and the layer name is not found, @@ -757,7 +776,7 @@ def remove_adapter(self, adapter_id: int) -> bool: def list_adapters(self) -> dict[int, LoRAModel]: return dict(self._registered_adapters) - def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]: + def get_adapter(self, adapter_id: int) -> LoRAModel | None: return self._registered_adapters.get(adapter_id) diff --git a/vllm/lora/ops/triton_ops/README_TUNING.md b/vllm/lora/ops/triton_ops/README_TUNING.md new file mode 100644 index 000000000000..fda95ea71891 --- /dev/null +++ b/vllm/lora/ops/triton_ops/README_TUNING.md @@ -0,0 +1,51 @@ +# Multi-LoRA Tuning + +**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`. Without this, the shrink/expand kernels will use default configurations. + +## Tuning Process + +Multi-LoRA shrink/expand Triton kernel tuning follows a methodology similar to [Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py). + +**Step 1** +Define the search space. An example search space: + +```python +block_m_range = [16, 32, 64, 128, 256] +block_n_range = [32, 64, 128, 256] +block_k_range = [32, 64, 128, 256] +num_warps_range = [4, 8] +num_stage_range = [2, 3, 4, 5] +num_ctas_range = [1] +split_k_range = [4, 8, 16, 32, 64] +``` + +**Step 2** +Get all hidden_state sizes and num_slices that the target model uses for a specific TP size. + +For example, we can acquire this information by checking [add_lora_linear](https://github.com/li2haipeng/vllm/blob/multi_lora_v01011/vllm/lora/punica_wrapper/punica_gpu.py#L192): + +```python +print(f"x_shape: {x.view(-1, x.shape[-1]).shape}") +print(f"num_slices: {len(output_slices)}") +for i in range(len(output_slices)): + print(f"a{i} shape: {lora_a_stacked[i].shape}") + print(f"b{i} shape: {lora_b_stacked[i].shape}") +print("y_shape", y.shape) +``` + +**Step 3** +Benchmark the shrink/expand kernel runtime with different kernel configurations generated from the pre-defined search space by performing a grid search to find the optimal kernel configuration. vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py) can be used to search for configurations for different shapes. + +## Config Files + +### File Name + +For `shrink`, the config file is named `{gpu_name}_SHRINK.json`, e.g. `NVIDIA_H200_SHRINK.json`. + +For `expand`, the config file is named `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.
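A sketch of assembling a tuned-config path from the naming scheme above (the README notes next that `gpu_name` comes from `torch.cuda.get_device_name()`); normalizing spaces to underscores in the device name is an assumption based on the example file names, not vLLM's exact loader logic:

```python
import os

import torch

# e.g. "NVIDIA H200" -> "NVIDIA_H200" (assumed normalization)
gpu_name = torch.cuda.get_device_name().replace(" ", "_")
config_dir = os.environ.get("VLLM_TUNED_CONFIG_FOLDER", ".")

shrink_config = os.path.join(config_dir, f"{gpu_name}_SHRINK.json")
# add_input=True variant of the expand kernel
expand_config = os.path.join(config_dir, f"{gpu_name}_EXPAND_TRUE.json")
```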
+ +The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()` + +### Json Structure + +Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]` diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index 805de4b6f657..436ea4ed00c8 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink @@ -9,4 +10,5 @@ "lora_expand", "lora_shrink", "LoRAKernelMeta", + "fused_moe_lora", ] diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py new file mode 100644 index 000000000000..94935d8dfe86 --- /dev/null +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import triton +import triton.language as tl + +from vllm.utils.torch_utils import direct_register_custom_op + +_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} + + +def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): + """ + `_LORA_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + """ + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + + if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None: + return ptr_tensor + + tensor_ptrs = [] + for lora_weight in lora_weights: + tensor_ptrs.append(lora_weight.data_ptr()) + ptr_tensor = torch.tensor(tensor_ptrs, device=device) + + _LORA_PTR_DICT[key] = ptr_tensor + return _LORA_PTR_DICT.get(key) + + +@triton.jit +def _fused_moe_lora_kernel( + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + num_experts, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
+ stride_am, + stride_ak, + stride_bl, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_tl, + stride_el, + # Meta-parameters + num_slice_a: tl.constexpr, + num_slice_c: tl.constexpr, + slice_a_size: tl.constexpr, + slice_c_size: tl.constexpr, + top_k: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + max_loras = tl.num_programs(axis=2) + + # calculate pid_m,pid_n + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + # get the expert_id to process curr shard + ind = lora_idx * stride_el + pid_m + expert_id = tl.load(expert_ids_ptr + ind) + if expert_id == -1: + return + + # get a_ptr,b_ptr,c_ptr + cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size + cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16)) + cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + token_ind = stride_tl * lora_idx + offs_token_id + offs_token = tl.load( + sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0 + ) + token_mask = offs_token < num_valid_tokens + + # get a_ptrs,b_ptrs + a_ptrs = cur_a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + b_ptrs = ( + cur_b_ptr + + lora_idx * stride_bl + + expert_id * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) + + # accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(tl.bfloat16) + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@torch.inference_mode() +def _fused_moe_lora( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, N, max_lora_rank,),...] 
+ topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + max_lora_rank: int, + top_k_num: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + mul_routed_weight: bool = False, +) -> None: + assert len(lora_a_stacked) == len(lora_b_stacked) > 0 + assert ( + sorted_token_ids.dim() + == expert_ids.dim() + == topk_weights.dim() + == qcurr_hidden_states.dim() + == 2 + ) + assert ( + sorted_token_ids.shape[0] + == expert_ids.shape[0] + == num_tokens_post_padded.shape[0] + ) + assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] + assert output.shape[0] == topk_weights.shape[0] + assert top_k_num == topk_weights.shape[1] + + device = qcurr_hidden_states.device + num_slices = len(lora_a_stacked) + + config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + } + + w1_lora_a_stacked = lora_a_stacked[0] + w1_lora_b_stacked = lora_b_stacked[0] + num_experts = lora_a_stacked[0].shape[1] + + N = max_lora_rank + M = topk_weights.shape[0] + EM = sorted_token_ids.shape[1] + K = qcurr_hidden_states.shape[1] + num_tokens = M * top_k_num + w1_output_dim_size = w1_lora_b_stacked.shape[2] + + lora_intermediate_cache1 = torch.zeros( + (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)), + dtype=torch.bfloat16, + device=device, + ) + + # slices + a_intermediate_size = num_slices * M * top_k_num * max_lora_rank + a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view( + num_slices, M, top_k_num, max_lora_rank + ) + b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view( + num_slices, M, top_k_num, w1_output_dim_size + ) + + b_ptr = _get_ptr(lora_a_stacked, device) + + grid = lambda META: ( + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + len(lora_a_stacked), + lora_a_stacked[0].shape[0], + ) + + _fused_moe_lora_kernel[grid]( + qcurr_hidden_states, + b_ptr, + a_intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + EM, + num_tokens, + num_experts, + qcurr_hidden_states.stride(0), + qcurr_hidden_states.stride(1), + w1_lora_a_stacked.stride(0), + w1_lora_a_stacked.stride(1), + w1_lora_a_stacked.stride(3), + w1_lora_a_stacked.stride(2), + a_intermediate_cache1.stride(2), + a_intermediate_cache1.stride(3), + sorted_token_ids.stride(0), + expert_ids.stride(0), + num_slice_a=1, + num_slice_c=num_slices, + slice_a_size=qcurr_hidden_states.numel(), + slice_c_size=a_intermediate_cache1.numel() // num_slices, + top_k=1 if mul_routed_weight else top_k_num, + MUL_ROUTED_WEIGHT=False, + **config, + ) + + b_ptr = _get_ptr(lora_b_stacked, device) + K = max_lora_rank + N = w1_output_dim_size + + # a_intermediate_cache1 = a_intermediate_cache1.view( + # M, -1, a_intermediate_cache1.shape[3] + # ) + + a_intermediate_cache1 = a_intermediate_cache1.view( + -1, a_intermediate_cache1.shape[3] + ) + + grid = lambda META: ( + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + len(lora_b_stacked), + lora_b_stacked[0].shape[0], + ) + _fused_moe_lora_kernel[grid]( + a_intermediate_cache1, + b_ptr, + b_intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + EM, + num_tokens, + num_experts, + a_intermediate_cache1.stride(0), + 
a_intermediate_cache1.stride(1), + w1_lora_b_stacked.stride(0), + w1_lora_b_stacked.stride(1), + w1_lora_b_stacked.stride(3), + w1_lora_b_stacked.stride(2), + b_intermediate_cache1.stride(2), + b_intermediate_cache1.stride(3), + sorted_token_ids.stride(0), + expert_ids.stride(0), + num_slice_a=num_slices, + num_slice_c=num_slices, + slice_a_size=a_intermediate_cache1.numel() // num_slices, + slice_c_size=b_intermediate_cache1.numel() // num_slices, + top_k=1, + MUL_ROUTED_WEIGHT=mul_routed_weight, + **config, + ) + for i in range(num_slices): + output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i] + + +def _fused_moe_lora_fake( + output: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + mul_routed_weight: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="fused_moe_lora", + op_func=_fused_moe_lora, + mutates_args=["output"], + fake_impl=_fused_moe_lora_fake, + ) + fused_moe_lora = torch.ops.vllm.fused_moe_lora + +except AttributeError: + fused_moe_lora = _fused_moe_lora diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index a7a552b9903d..fd4c1364de7e 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -10,9 +10,9 @@ import torch from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel -from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr +from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit @@ -201,12 +201,21 @@ def _lora_expand( NUM_SLICES = len(lora_b_weights) # Triton kernel configs. 
- BLOCK_M = 64 - BLOCK_N = 128 - BLOCK_K = 16 - NUM_WARPS = 4 - NUM_CTAS = 1 - NUM_STAGES = 2 + kernel_config = get_lora_op_configs( + op_type="expand", + max_loras=MAX_LORAS, + batch=M, + hidden_size=MAX_N, + rank=K, + num_slices=NUM_SLICES, + add_inputs=add_inputs, + ) + BLOCK_M = kernel_config["block_m"] + BLOCK_N = kernel_config["block_n"] + BLOCK_K = kernel_config["block_k"] + NUM_WARPS = kernel_config["num_warps"] + NUM_CTAS = kernel_config["num_ctas"] + NUM_STAGES = kernel_config["num_stages"] EVEN_K = K % BLOCK_K == 0 # type: ignore diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index df343305d710..c3bef7680dd0 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -5,7 +5,6 @@ """ from dataclasses import dataclass -from typing import Union import torch @@ -31,7 +30,7 @@ class LoRAKernelMeta: @staticmethod def make( - max_loras: int, max_num_tokens: int, device: Union[torch.device, str] + max_loras: int, max_num_tokens: int, device: torch.device | str ) -> "LoRAKernelMeta": token_lora_mapping = torch.empty( max_num_tokens, dtype=torch.int32, device=device diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 1e7e43e30de7..8d126197f83e 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -10,9 +10,9 @@ import torch from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel -from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr +from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit @@ -169,6 +169,8 @@ def _lora_shrink( assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + output_tensor.zero_() + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = ( _get_lora_a_ptr(lora_a_weights, inputs.device) ) @@ -177,14 +179,21 @@ def _lora_shrink( MAX_LORAS = lora_ids.size(0) # Triton kernel configs - BLOCK_M = 32 - BLOCK_N = 16 - BLOCK_K = 256 if M < 128 else 32 - SPLIT_K = 64 if M < 128 else 8 - NUM_WARPS = 4 - NUM_CTAS = 1 - NUM_STAGES = 2 - + kernel_config = get_lora_op_configs( + "shrink", + max_loras=MAX_LORAS, + batch=M, + hidden_size=K, + rank=N, + num_slices=NUM_SLICES, + ) + BLOCK_M = kernel_config["block_m"] + BLOCK_N = kernel_config["block_n"] + BLOCK_K = kernel_config["block_k"] + SPLIT_K = kernel_config["split_k"] + NUM_WARPS = kernel_config["num_warps"] + NUM_STAGES = kernel_config["num_stages"] + NUM_CTAS = kernel_config["num_ctas"] EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore # TODO (varun): This grid formulation maximizes parallelization at the diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 3a3e8fc8931e..9ffb6dc3d85e 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -1,8 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +import json +from pathlib import Path +from typing import Any + import torch +from vllm import envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} _LORA_B_PTR_DICT: 
dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} @@ -133,3 +143,108 @@ def _get_lora_b_ptr( MAX_N, ) return _LORA_B_PTR_DICT.get(key) + + +@functools.lru_cache +def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None: + user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER + if user_defined_config_folder is not None: + gpu_name = torch.cuda.get_device_name() + gpu_name = gpu_name.replace(" ", "_") + gpu_name = gpu_name.replace("-", "_") + + config_fname = None + if op_type == "shrink": + config_fname = f"{gpu_name}_{op_type.upper()}.json" + else: + assert op_type == "expand" + config_fname = ( + f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json" + ) + + config_path = Path(f"{user_defined_config_folder}/{config_fname}") + if not config_path.exists(): + logger.warning_once(f"No LoRA kernel configs founded in {config_path}") + return None + + # Load json + logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.") + with open(str(config_path)) as f: + config_data = json.load(f) + else: + config_data = None + + return config_data + + +@functools.lru_cache +def get_lora_op_configs( + op_type: str, + max_loras: int, + batch: int, + hidden_size: int, + rank: int, + num_slices: int, + add_inputs: bool | None = None, +) -> dict[str, int | None]: + assert op_type in ["shrink", "expand"] + + # default config + default = {} + if op_type == "shrink": + default = { + "block_m": 32, + "block_n": 16, + "block_k": 256 if batch < 128 else 32, + "split_k": 64 if batch < 128 else 8, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2, + "max_nreg": None, + } + else: + default = { + "block_m": 64, + "block_n": 128, + "block_k": 16, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2, + "max_nreg": None, + } + m = batch + + k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size) + + config_data: Any + config_data = load_lora_op_config(op_type, add_inputs) + if not config_data: + logger.warning_once("Using default LoRA kernel configs") + return default + + # config is structured as config_data[max_loras][num_slices][m][k][n] = {} + # slice by max_loras + config_data = ( + config_data.get(str(max_loras)) + or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))] + ) + # slice by num_slices + config_data = config_data[str(num_slices)] + # slice by m + config_data = ( + config_data.get(str(m)) + or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))] + ) + # slice by k + config_data = ( + config_data.get(str(k)) + or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))] + ) + # slice by n + config_data = ( + config_data.get(str(n)) + or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))] + ) + + assert config_data is not None + return config_data diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index 48412eab92d8..975c3d8fc0a7 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -7,7 +7,7 @@ import math import os from dataclasses import MISSING, dataclass, field, fields -from typing import Literal, Optional, Union +from typing import Literal from vllm.config.lora import LoRAConfig from vllm.logger import init_logger @@ -27,17 +27,17 @@ class PEFTHelper: # Required fields r: int lora_alpha: int - target_modules: Union[list[str], str] + target_modules: list[str] | str - bias: Literal["none", "all", "lora_only"] = field(default="none") - modules_to_save: Optional[list[str]] = field(default=None) + bias: Literal["none"] = 
field(default="none") + modules_to_save: list[str] | None = field(default=None) # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732) use_rslora: bool = field(default=False) # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353) use_dora: bool = field(default=False) # Extra vllm field, start with 'vllm_' to avoid conflict vllm_lora_scaling_factor: float = field(default=1.0) - vllm_max_position_embeddings: Optional[int] = field(default=False) + vllm_max_position_embeddings: int | None = field(default=False) def _validate_features(self) -> list[str]: """ @@ -81,8 +81,8 @@ def from_dict(cls, config_dict: dict) -> "PEFTHelper": def from_local_dir( cls, lora_path: str, - max_position_embeddings: Optional[int], - tensorizer_config_dict: Optional[dict] = None, + max_position_embeddings: int | None, + tensorizer_config_dict: dict | None = None, ) -> "PEFTHelper": lora_config_path = os.path.join(lora_path, "adapter_config.json") @@ -122,7 +122,7 @@ def validate_legal(self, lora_config: LoRAConfig) -> None: f"LoRA rank {self.r} is greater than max_lora_rank" f" {lora_config.max_lora_rank}." ) - if self.bias != "none" and not lora_config.bias_enabled: - error_msg.append("Adapter bias cannot be used without bias_enabled.") + if self.bias != "none": + error_msg.append("Adapter bias is not supported.") if error_msg: raise ValueError(f"{' '.join(error_msg)}") diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 770c3cf7b073..5b4a18cf4789 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -8,7 +8,7 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import torch @@ -28,7 +28,7 @@ class PunicaWrapperABC(ABC): def update_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], + lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -42,12 +42,12 @@ def update_metadata( @abstractmethod def add_shrink( self, - y: Union[tuple[torch.Tensor, ...], torch.Tensor], + y: tuple[torch.Tensor, ...] | torch.Tensor, x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Performs GEMM for multiple slices of lora_a. """ @@ -58,16 +58,15 @@ def add_shrink( def add_expand( self, y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: tuple[torch.Tensor, ...] | torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. """ raise NotImplementedError @@ -79,7 +78,7 @@ def add_lora_embedding( lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. 
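The `add_shrink`/`add_expand`/`add_lora_linear` docstrings in the surrounding hunks describe the per-slice math the punica wrappers implement. For orientation, the snippet below is a minimal dense-PyTorch sketch of that shrink-then-expand computation for a single output slice, assuming plain 2-D `lora_a` (hidden_size x rank) and `lora_b` (rank x slice_size) matrices rather than the stacked per-LoRA layouts the wrappers actually index; it illustrates only the semantics, not the Triton kernels.

```python
import torch


def lora_slice_reference(
    y: torch.Tensor,       # (num_tokens, slice_size), updated in place
    x: torch.Tensor,       # (num_tokens, hidden_size)
    lora_a: torch.Tensor,  # (hidden_size, rank)
    lora_b: torch.Tensor,  # (rank, slice_size)
    scale: float,
) -> torch.Tensor:
    # "shrink": project the input down to the LoRA rank and apply the scale.
    buffer = (x @ lora_a) * scale   # (num_tokens, rank)
    # "expand": project back up and accumulate into the output slice.
    y += buffer @ lora_b            # (num_tokens, slice_size)
    return y


# Toy shapes only; real calls go through add_shrink/add_expand.
y = torch.zeros(4, 32)
x = torch.randn(4, 16)
lora_slice_reference(y, x, torch.randn(16, 8), torch.randn(8, 32), scale=1.0)
```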
@@ -93,13 +92,12 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, + buffer: tuple[torch.Tensor, ...] | None = None, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Applicable to linear-related lora. """ @@ -115,9 +113,9 @@ def add_lora_logits( lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None, + buffer: torch.Tensor | None = None, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Applies lora specifically for LogitsProcessorWithLoRA. """ @@ -135,7 +133,7 @@ def __init__( self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], + device: torch.device | str, **kwargs, ): self._token_lora_indices = torch.empty( @@ -154,7 +152,7 @@ def __init__( # 4 is the number of indices tensors. # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices - self.indices_len: list[Optional[int]] = [None] * 4 + self.indices_len: list[int | None] = [None] * 4 # these attributes are the information required for sgmv kernel self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device) self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device) @@ -171,7 +169,7 @@ def __init__( def _update_base_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], + lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -222,38 +220,6 @@ def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: self.token_nums = token_nums self.no_lora = no_lora - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: tuple[int, ...], - lora_bias_stacked: tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left : offset_left + slice] += bias - offset_left += slice - - return output.view_as(org_output) - @property def prefill_metadata( self, @@ -316,7 +282,7 @@ def embeddings_indices(self) -> torch.Tensor: def update_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], + lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -336,12 +302,12 @@ def update_metadata( @abstractmethod def add_shrink( self, - y: Union[tuple[torch.Tensor, ...], torch.Tensor], + y: tuple[torch.Tensor, ...] | torch.Tensor, x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Performs GEMM for multiple slices of lora_a. @@ -363,31 +329,27 @@ def add_shrink( def add_expand( self, y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: tuple[torch.Tensor, ...] 
| torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: offset = offset_start for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size offset_start (int): The starting position of y, defaults to 0 add_inputs (bool): Defaults to True. @@ -404,7 +366,7 @@ def add_lora_embedding( lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. and this layer only requires the expand operation. @@ -427,13 +389,12 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, + buffer: tuple[torch.Tensor, ...] | None = None, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Applicable to linear-related lora. @@ -444,14 +405,13 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. @@ -468,9 +428,9 @@ def add_lora_logits( lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None, + buffer: torch.Tensor | None = None, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Applies lora specifically for LogitsProcessorWithLoRA. @@ -488,3 +448,42 @@ def add_lora_logits( """ # TODO: implement it based on torch ops raise NotImplementedError + + def moe_lora_align_block_size( + self, + topk_ids: torch.Tensor, + num_tokens: int, + block_size: int, + num_experts: int, + max_loras: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns tokens and experts into block-sized chunks for LoRA-based + mixture-of-experts (MoE) execution. 
+ """ + # TODO: implement it based on torch ops + raise NotImplementedError + + def add_lora_fused_moe( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + config, + mul_routed_weight=False, + ): + """ + Performs a fused forward computation for LoRA of + Mixture-of-Experts (MoE) layer. + """ + # TODO: implement it based on torch ops + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index c51a13db873c..1a700d9bf1f0 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional, Union +from collections.abc import Callable import torch @@ -30,7 +30,7 @@ def __init__( self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], + device: torch.device | str, **kwargs, ): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) @@ -165,7 +165,7 @@ def _apply_shrink( def add_shrink( self, - y: Union[tuple[torch.Tensor, ...], torch.Tensor], + y: tuple[torch.Tensor, ...] | torch.Tensor, x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, @@ -197,40 +197,32 @@ def add_shrink( def add_expand( self, y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: tuple[torch.Tensor, ...] | torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start - if lora_bias_stacked is not None: - self._apply_bias( - self.token_lora_indices, y, output_slices, lora_bias_stacked - ) for slice_idx in range(len(lora_b_stacked)): self._apply_expand( y, @@ -276,11 +268,10 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, + buffer: tuple[torch.Tensor, ...] | None = None, **kwargs, ) -> None: """ @@ -293,25 +284,19 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. 
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias( - self.token_lora_indices, y, output_slices, lora_bias_stacked - ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -323,7 +308,7 @@ def add_lora_linear( ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) self.add_expand( - y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs + y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs ) def add_lora_logits( @@ -334,7 +319,7 @@ def add_lora_logits( lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None, + buffer: torch.Tensor | None = None, **kwargs, ) -> None: """ diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 431e97102faf..c2c26a01ee03 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -7,15 +7,23 @@ https://arxiv.org/abs/2310.18547 """ -from typing import Optional, Union, final +from typing import final import torch from vllm.lora.layers import LoRAMapping -from vllm.triton_utils import HAS_TRITON +from vllm.triton_utils import HAS_TRITON, triton +from vllm.utils import round_up if HAS_TRITON: - from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops import ( + LoRAKernelMeta, + fused_moe_lora, + lora_expand, + lora_shrink, + ) + +from vllm import _custom_ops as ops from .punica_base import PunicaWrapperBase @@ -32,7 +40,7 @@ def __init__( self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], + device: torch.device | str, **kwargs, ): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) @@ -50,7 +58,7 @@ def __init__( def update_metadata( self, mapping: LoRAMapping, - lora_index_to_id: list[Optional[int]], + lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -101,36 +109,29 @@ def add_expand( y: torch.Tensor, x: torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. 
""" y_org = y y = y.view(-1, y.shape[-1]) - if lora_bias_stacked is not None: - token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0)) - self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked) assert x.ndim == 3 assert x.size(0) == len(output_slices) @@ -183,11 +184,10 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, - buffer: Optional[torch.Tensor] = None, + buffer: torch.Tensor | None = None, **kwargs, ) -> None: """ @@ -200,36 +200,31 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] - + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[torch.Tensor]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0)) - y = self._apply_bias( - token_lora_indices, y, output_slices, lora_bias_stacked - ) - - if buffer is None: - r = lora_b_stacked[0].size(-1) - # We set the buffer to be float32 by default, refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros( # type: ignore - (len(output_slices), x.size(0), r), - dtype=torch.float32, - device=x.device, - ) + + assert buffer is None, ( + "To minimize overhead, the buffer should be created by " + ".add_lora_linear() instead of being passed in." + ) + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, refer to: + # https://github.com/triton-lang/triton/issues/1387 + # Note: buffer is zeroed inside the shrink op + buffer = torch.empty( + (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device + ) + self.add_shrink( buffer, # type: ignore x, @@ -241,7 +236,6 @@ def add_lora_linear( y, buffer, # type: ignore lora_b_stacked, - None, output_slices, add_inputs=True, **kwargs, @@ -255,7 +249,7 @@ def add_lora_logits( lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None, + buffer: torch.Tensor | None = None, **kwargs, ) -> None: """ @@ -277,10 +271,15 @@ def add_lora_logits( y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) r = lora_b_stacked.size(-1) - if buffer is None: - # We set the buffer to be float32 by default, refer to: - # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) + + assert buffer is None, ( + "To minimize overhead, the buffer should be created by " + ".add_lora_linear() instead of being passed in." 
+ ) + # We set the buffer to be float32 by default, refer to: + # https://github.com/triton-lang/triton/issues/1387 + # Note: buffer is zeroed inside the shrink op + buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device) lora_shrink( x, @@ -298,3 +297,93 @@ def add_lora_logits( add_inputs=True, ) y = y.view_as(y_org) + + def moe_lora_align_block_size( + self, + topk_ids: torch.Tensor, + num_tokens: int, + block_size: int, + num_experts: int, + max_loras: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns tokens and experts into block-sized chunks for LoRA-based + mixture-of-experts (MoE) execution. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) + + (token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args( + num_tokens + ) + + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad + + def add_lora_fused_moe( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + config, + mul_routed_weight=False, + ): + """ + Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. 
+ """ + fused_moe_lora( + y, + x, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_lora_rank, + top_k_num, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + mul_routed_weight, + ) diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index c017721803fe..d8763e913e3a 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -3,7 +3,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import resolve_obj_by_qualname +from vllm.utils.import_utils import resolve_obj_by_qualname from .punica_base import PunicaWrapperBase diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 5d2f05b815be..090878dcd254 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import torch import torch.nn.functional as F @@ -29,7 +29,7 @@ def __init__( self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], + device: torch.device | str, **kwargs, ): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) @@ -105,12 +105,12 @@ def expand_slice( def add_shrink( self, - y: Union[tuple[torch.Tensor, ...], torch.Tensor], + y: tuple[torch.Tensor, ...] | torch.Tensor, x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Performs GEMM for multiple slices of lora_a. @@ -137,30 +137,26 @@ def add_shrink( def add_expand( self, y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: tuple[torch.Tensor, ...] | torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> torch.Tensor: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ @@ -168,10 +164,6 @@ def add_expand( y = y.view(-1, y.shape[-1]) offset_left = 0 - if lora_bias_stacked is not None: - y = self._apply_bias( - self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked - ) for slice_idx in range(len(lora_b_stacked)): y = self.expand_slice( y, @@ -214,11 +206,10 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, + buffer: tuple[torch.Tensor, ...] 
| None = None, **kwargs, ) -> torch.Tensor: """ @@ -231,25 +222,19 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will not be changed in-place. x (torch.Tensor): Input tensor (T, E) lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias( - self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked - ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -261,7 +246,7 @@ def add_lora_linear( ) buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) return self.add_expand( - y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs + y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs ) def add_lora_logits( @@ -272,7 +257,7 @@ def add_lora_logits( lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None, + buffer: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -299,49 +284,12 @@ def add_lora_logits( y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) return y.view_as(y_org) - def _apply_bias( - self, - indices: torch.Tensor, - output: torch.Tensor, - output_slices: tuple[int, ...], - lora_bias_stacked: tuple[Optional[torch.Tensor], ...], - ): - """Applies bias to output - - Input shapes: - lora_bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = lora_bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias = torch.where(indices[:, None] == -1, 0, bias) - - bias = F.pad( - bias, (offset_left, output.shape[1] - (offset_left + slice), 0, 0) - ) - - output += bias - offset_left += slice - - return output.view_as(org_output) - # This performs the same tensor ops as the base method, except it does them # on the CPU then transfers the results to the TPU def _update_base_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], + lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, extra_vocab_size: int, diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 5196199b2ac3..b95087d0ff83 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -7,7 +7,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import Optional, Union, final +from typing import final import torch @@ -29,7 +29,7 @@ def __init__( self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], + device: torch.device | str, **kwargs, ): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) @@ -40,7 +40,7 @@ def 
__init__( def update_metadata( self, mapping: LoRAMapping, - lora_index_to_id: list[Optional[int]], + lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -108,36 +108,29 @@ def add_expand( y: torch.Tensor, x: torch.Tensor, lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, ) -> None: """ - Performs GEMM and bias addition for multiple slices of lora_b. + Performs GEMM for multiple slices of lora_b. Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] offset += slice Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): - bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) - if lora_bias_stacked is not None: - token_lora_indices = self._get_token_lora_indices(y) - self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked) assert x.ndim == 3 assert x.size(0) == len(output_slices) @@ -184,11 +177,10 @@ def add_lora_linear( x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, output_slices: tuple[int, ...], *, - buffer: Optional[torch.Tensor] = None, + buffer: torch.Tensor | None = None, **kwargs, ) -> None: """ @@ -201,26 +193,19 @@ def add_lora_linear( @ lora_a_stacked[indices[i], layer_idx, :, :] @ lora_b_stacked[indices[i], layer_idx, :, :] * scale - ).squeeze(0)+lora_bias_stacked[i] + ).squeeze(0) Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[torch.Tensor]): Defaults to None. 
""" assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - if lora_bias_stacked is not None: - assert len(lora_bias_stacked) == len(output_slices) - token_lora_indices = self._get_token_lora_indices(y) - y = self._apply_bias( - token_lora_indices, y, output_slices, lora_bias_stacked - ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -242,7 +227,6 @@ def add_lora_linear( y, buffer, # type: ignore lora_b_stacked, - None, output_slices, add_inputs=True, **kwargs, @@ -263,7 +247,7 @@ def add_lora_logits( lora_b_stacked: torch.Tensor, scale, *, - buffer: Optional[torch.Tensor] = None, + buffer: torch.Tensor | None = None, **kwargs, ) -> None: """ diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index 90d1614e674d..584745f86b1a 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import torch @@ -51,7 +51,7 @@ def compute_meta( # TODO see if this can be vectorized def convert_mapping( mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], + lora_index_to_id: list[int | None], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -104,7 +104,7 @@ def convert_mapping( embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx - indices_list: list[Union[list[int], torch.Tensor]] = [ + indices_list: list[list[int] | torch.Tensor] = [ index_mapping_indices, lora_indices, embedding_indices, diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 650e060a5804..c97e435e3216 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import warnings -from typing import Optional import msgspec @@ -27,10 +26,10 @@ class LoRARequest( lora_name: str lora_int_id: int lora_path: str = "" - lora_local_path: Optional[str] = msgspec.field(default=None) - long_lora_max_len: Optional[int] = None - base_model_name: Optional[str] = msgspec.field(default=None) - tensorizer_config_dict: Optional[dict] = None + lora_local_path: str | None = msgspec.field(default=None) + long_lora_max_len: int | None = None + base_model_name: str | None = msgspec.field(default=None) + tensorizer_config_dict: dict | None = None def __post_init__(self): if self.lora_int_id < 1: diff --git a/vllm/lora/resolver.py b/vllm/lora/resolver.py index d366b94521cd..bcfe26467cfb 100644 --- a/vllm/lora/resolver.py +++ b/vllm/lora/resolver.py @@ -4,7 +4,6 @@ from abc import ABC, abstractmethod from collections.abc import Set from dataclasses import dataclass, field -from typing import Optional from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -24,7 +23,7 @@ class LoRAResolver(ABC): @abstractmethod async def resolve_lora( self, base_model_name: str, lora_name: str - ) -> Optional[LoRARequest]: + ) -> LoRARequest | None: """Abstract method to resolve and fetch a LoRA model adapter. Implements logic to locate and download LoRA adapter based on the name. 
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 5e55d44ce8d9..0f43ff06d8f2 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional import huggingface_hub import regex as re @@ -23,6 +23,7 @@ BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA, + FusedMoEWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithShardedLoRA, @@ -35,7 +36,9 @@ RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA, ) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping if TYPE_CHECKING: from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -58,15 +61,24 @@ MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA, + FusedMoEWithLoRA, } +def is_moe_model(model: nn.Module) -> bool: + """Checks if the model contains FusedMoE layers and warns the user.""" + if any(isinstance(module, FusedMoE) for module in model.modules()): + logger.info_once("MoE model detected. Using fused MoE LoRA implementation.") + return True + return False + + def from_layer( layer: nn.Module, max_loras: int, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig] = None, + model_config: PretrainedConfig | None = None, ) -> nn.Module: for lora_cls in _all_lora_classes: # specifying kwargs so they can be easily accessed in decorator @@ -87,7 +99,7 @@ def from_layer_logits_processor( lm_head: "ParallelLMHead", max_loras: int, lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, + model_config: PretrainedConfig | None = None, ) -> LogitsProcessorWithLoRA: ret = LogitsProcessorWithLoRA( layer, @@ -112,7 +124,7 @@ def replace_submodule( def parse_fine_tuned_lora_name( name: str, weights_mapper: Optional["WeightsMapper"] = None -) -> tuple[str, bool, bool]: +) -> tuple[str, bool]: """Parse the name of lora weights. args: @@ -124,7 +136,6 @@ def parse_fine_tuned_lora_name( tuple(module_name, is_lora_a): module_name: the name of the module, e.g. model.dense1, is_lora_a whether the tensor is lora_a or lora_b. - is_bias whether the tensor is lora bias. 
""" # LoRA weight qualified name usually starts with `base_model.model.`, @@ -146,21 +157,17 @@ def parse_fine_tuned_lora_name( parts = name.split(".") if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): new_name = ".".join(parts[start_index:-2]) - return new_name, parts[-2] == "lora_A", False + return new_name, parts[-2] == "lora_A" if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": new_name = ".".join(parts[start_index:-1]) - return new_name, parts[-1] == "lora_embedding_A", False - - if parts[-1] == "bias": - new_name = ".".join(parts[start_index:-2]) - return new_name, False, True + return new_name, parts[-1] == "lora_embedding_A" raise ValueError(f"{name} is unsupported LoRA weight") def is_regex_target_modules( - load_modules: Union[str, list[str]], expected_lora_modules: list[str] + load_modules: str | list[str], expected_lora_modules: list[str] ) -> bool: """ PEFT supports passing `target_modules` in the form of regular expressions, @@ -210,6 +217,9 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]: if isinstance(module, (LinearBase,)): supported_lora_modules.add(name.split(".")[-1]) + if isinstance(module, (FusedMoE,)): + supported_lora_modules.add(name.split(".")[-1]) + return list(supported_lora_modules) @@ -257,3 +267,27 @@ def get_adapter_absolute_path(lora_path: str) -> str: return lora_path return local_snapshot_path + + +def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]: + if is_moe_model(model): + if moe_packed_mapping := get_moe_expert_mapping(model): + # This method generates and returns a dictionary mapping packed module + # names to lists of their corresponding submodule names. It includes + # both static mappings and dynamic mappings for expert layers, where + # the expert indices are expanded based on the configured number + # of routed experts. 
+ packed_modules_mapping = get_packed_modules_mapping(model) + + packed_modules_mapping["experts"] = [ + weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping + ] + + return packed_modules_mapping + else: + raise AttributeError( + "To support LoRA for MoE model, " + "'get_expert_mapping' must be implemented" + ) + else: + return get_packed_modules_mapping(model) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 3ca819fb732c..b85151f2c759 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import Any, Literal, Optional, Union +from typing import Any, Literal import torch @@ -40,7 +40,7 @@ def __init__( self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules - self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self._cached_dummy_lora: None | Literal[False] | LoRAModel = False self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs self.max_num_batched_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens @@ -94,7 +94,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: expected_lora_modules.extend(packed_modules_mapping[module]) else: expected_lora_modules.append(module) - + if module == "experts": + expected_lora_modules.append(module) expected_lora_modules = list(set(expected_lora_modules)) lora_path = get_adapter_absolute_path(lora_request.lora_path) @@ -166,7 +167,7 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def pin_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.pin_adapter(adapter_id) - def set_active_adapters(self, requests: set[Any], mapping: Optional[Any]) -> None: + def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None: self._apply_adapters(requests) if mapping is not None: self._adapter_manager.set_adapter_mapping(mapping) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 6a0ea266378a..9ef696d80712 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch.nn as nn @@ -114,9 +113,9 @@ def enabled(cls) -> bool: custom_ops = compilation_config.custom_ops if not hasattr(cls, "name"): logger.warning_once( - "Custom op %s was not registered, which means it won't appear\ - in the op registry. It will be enabled/disabled based on the\ - global settings.", # noqa: E501 + "Custom op %s was not registered, which means it won't appear " + "in the op registry. 
It will be enabled/disabled based on the " + "global settings.", cls.__name__, ) return CustomOp.default_on() @@ -171,7 +170,7 @@ def decorator(op_cls): # or # - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod") @classmethod - def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None): + def register_oot(cls, _decorated_op_cls=None, name: str | None = None): def decorator(op_cls): reg_name = name if name is not None else cls.__name__ assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 96745b99f7a7..3471ee327cf8 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -3,7 +3,6 @@ """Custom activation functions.""" import math -from typing import Optional import torch import torch.nn as nn @@ -18,7 +17,7 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import LazyDict +from vllm.utils.collection_utils import LazyDict logger = init_logger(__name__) @@ -81,7 +80,8 @@ def __init__(self): elif current_platform.is_cpu(): self._forward_method = self.forward_native - def forward_native(self, x: torch.Tensor) -> torch.Tensor: + @staticmethod + def forward_native(x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:] @@ -486,7 +486,7 @@ def __init__( act_module: nn.Module, intermediate_size: int, input_is_parallel: bool = True, - params_dtype: Optional[torch.dtype] = None, + params_dtype: torch.dtype | None = None, ): super().__init__() self.act = act_module diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py index fa74c20840da..ffbef470b186 100644 --- a/vllm/model_executor/layers/attention_layer_base.py +++ b/vllm/model_executor/layers/attention_layer_base.py @@ -5,6 +5,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING +from vllm.config import VllmConfig +from vllm.v1.kv_cache_interface import KVCacheSpec + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -22,3 +25,11 @@ class AttentionLayerBase(ABC): def get_attn_backend(self) -> type["AttentionBackend"]: """Get the attention backend class for this layer.""" pass + + @abstractmethod + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: + """ + Get the KV cache spec for this layer. + May be None if the layer does not need KV cache. 
+ """ + pass diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 9fd85d1e9e19..7368bfd35fec 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -4,12 +4,16 @@ import os from collections import namedtuple from collections.abc import Callable -from typing import Any, Union +from typing import Any import torch +import vllm.envs as envs +from vllm.logger import init_logger from vllm.triton_utils import tl, triton +logger = init_logger(__name__) + def _matmul_launch_metadata( grid: Callable[..., Any], kernel: Any, args: dict[str, Any] @@ -130,15 +134,12 @@ def matmul_kernel_persistent( bias_ptrs = bias_ptr + offs_cn bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32) accumulator += bias - if c_ptr.dtype.element_ty == tl.float8e4nv: - c = accumulator.to(tl.float8e4nv) - else: - c = accumulator.to(tl.float16) + c = accumulator.to(c_ptr.dtype.element_ty) tl.store(c_ptrs, c, mask=c_mask) def matmul_persistent( - a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None + a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None ): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" @@ -375,7 +376,7 @@ def mean_dim( input: torch.Tensor, dim: int, keepdim: bool = False, - dtype: Union[torch.dtype, None] = None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: """ Triton implementation of torch.mean with single dimension reduction. @@ -391,7 +392,6 @@ def mean_dim( Tensor with mean values along specified dimension """ # Validate inputs - assert input.is_cuda, "Input must be a CUDA tensor" assert -input.ndim <= dim < input.ndim, ( f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" ) @@ -466,6 +466,45 @@ def mm_batch_invariant(a, b): return matmul_persistent(a, b) +def matmul_batch_invariant(a, b, *, out=None): + # torch.matmul can handle various dimensions + # For 2D x 2D, it's the same as mm + if a.ndim == 2 and b.ndim == 2: + result = matmul_persistent(a, b) + if out is not None: + out.copy_(result) + return out + return result + elif a.ndim == 3 and b.ndim == 3: + # Handle batched case like bmm + return bmm_batch_invariant(a, b, out=out) + else: + raise ValueError( + f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, " + f"got shapes {a.shape} and {b.shape}" + ) + + +def bmm_batch_invariant(a, b, *, out=None): + # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N) + # Process each batch separately with our persistent kernel + if a.ndim == 3 and b.ndim == 3: + results = [] + for i in range(a.shape[0]): + results.append(matmul_persistent(a[i], b[i])) + result = torch.stack(results, dim=0) + + if out is not None: + out.copy_(result) + return out + return result + else: + raise ValueError( + f"bmm_batch_invariant expects 3D tensors, " + f"got shapes {a.shape} and {b.shape}" + ) + + def addmm_batch_invariant(bias, a, b): return matmul_persistent(a, b, bias=bias) @@ -475,13 +514,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float): return log_softmax(input, dim=dim) -def mean_batch_invariant( - input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None -): +def softmax_batch_invariant(input, dim, dtype=None): + # Compute softmax in a deterministic way + # First subtract max for numerical stability (standard practice) + input_max = torch.amax(input, dim=dim, keepdim=True) + input = input - input_max + exp_x = torch.exp(input) + sum_exp_x = 
torch.sum(exp_x, dim=dim, keepdim=True) + return exp_x / sum_exp_x + + +def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None): assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}" result = input.to(torch.float32) + if len(dim) == 0: + dim = [i for i in range(len(input.shape))] + # Sort dimensions to reduce from largest to smallest to handle shifting dims # during iterative reduction. sorted_dims = sorted([d % input.ndim for d in dim], reverse=True) @@ -498,8 +548,134 @@ def mean_batch_invariant( return result +@triton.jit +def _rms_norm_kernel( + input_ptr, + weight_ptr, + output_ptr, + input_row_stride, + output_row_stride, + n_cols, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + Compute RMS normalization along the last dimension of a 2D tensor. + RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight + Each block handles one row of the input tensor. + """ + row_idx = tl.program_id(0).to(tl.int64) + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # Step 1: Compute sum of squares in float32 to avoid overflow + sum_sq = tl.zeros([1], dtype=tl.float32) + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + # Convert to float32 for accumulation to prevent overflow + vals_f32 = vals.to(tl.float32) + sq_vals = vals_f32 * vals_f32 + sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0)) + + # Step 2: Compute RMS (root mean square) in float32 + mean_sq = sum_sq / n_cols + rms = tl.sqrt(mean_sq + eps) + inv_rms = 1.0 / rms + + # Step 3: Normalize and apply weight + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0) + # Compute in float32 then convert back to input dtype + vals_f32 = vals.to(tl.float32) + weight_f32 = weight.to(tl.float32) + output_f32 = vals_f32 * inv_rms * weight_f32 + output = output_f32.to(vals.dtype) + tl.store(output_row_start_ptr + col_idx, output, mask=mask) + + +def rms_norm( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + Compute RMS normalization using Triton kernel. 
+ + RMS Norm normalizes the input by the root mean square and scales by weight: + output = input / sqrt(mean(input^2) + eps) * weight + + Args: + input: Input tensor of shape (..., hidden_size) + weight: Weight tensor of shape (hidden_size,) + eps: Small constant for numerical stability + + Returns: + Tensor with RMS normalization applied along the last dimension + """ + assert weight.dim() == 1, "Weight must be 1-dimensional" + assert input.shape[-1] == weight.shape[0], ( + f"Input last dimension ({input.shape[-1]}) must match " + f"weight dimension ({weight.shape[0]})" + ) + + # Flatten all dimensions except the last one + original_shape = input.shape + input_2d = input.reshape(-1, input.shape[-1]) + input_2d = input_2d.contiguous() + weight = weight.contiguous() + + n_rows, n_cols = input_2d.shape + + output = torch.empty_like(input_2d) + BLOCK_SIZE = 1024 + grid = (n_rows,) + _rms_norm_kernel[grid]( + input_2d, + weight, + output, + input_2d.stride(0), + output.stride(0), + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output.reshape(original_shape) + + +def rms_norm_batch_invariant( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + Batch-invariant wrapper for RMS normalization. + + This function provides a deterministic, batch-invariant implementation + of RMS normalization for use with the batch_invariant mode. + + Args: + input: Input tensor of shape (..., hidden_size) + weight: Weight tensor of shape (hidden_size,) + eps: Small constant for numerical stability + + Returns: + RMS normalized tensor + """ + return rms_norm(input, weight, eps=eps) + + +def linear_batch_invariant(input, weight, bias=None): + output = mm_batch_invariant(input, weight.t()) + if bias is not None: + output = output + bias + return output + + _batch_invariant_MODE = False _batch_invariant_LIB = None +_original_torch_bmm = None def is_batch_invariant_mode_enabled(): @@ -507,7 +683,7 @@ def is_batch_invariant_mode_enabled(): def enable_batch_invariant_mode(): - global _batch_invariant_MODE, _batch_invariant_LIB + global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm if _batch_invariant_MODE: return @@ -515,16 +691,28 @@ def enable_batch_invariant_mode(): _batch_invariant_LIB = torch.library.Library("aten", "IMPL") _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA") _batch_invariant_LIB.impl( "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA" ) + _batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA") + # Also monkeypatch torch.bmm directly as a fallback + _original_torch_bmm = torch.bmm + torch.bmm = bmm_batch_invariant + def disable_batch_invariant_mode(): - global _batch_invariant_MODE, _batch_invariant_LIB + global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm if _batch_invariant_LIB is not None: _batch_invariant_LIB._destroy() + if _original_torch_bmm is not None: + torch.bmm = _original_torch_bmm + _original_torch_bmm = None _batch_invariant_MODE = False _batch_invariant_LIB = None @@ -550,8 +738,8 @@ def 
get_batch_invariant_attention_block_size() -> AttentionBlockSize: return AttentionBlockSize(block_m=16, block_n=16) -def vllm_kernel_override_batch_invariant(): - env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT" +def vllm_is_batch_invariant(): + env_key = "VLLM_BATCH_INVARIANT" is_overridden = False val = os.getenv(env_key, "0") try: @@ -561,8 +749,55 @@ def vllm_kernel_override_batch_invariant(): return is_overridden +def override_envs_for_invariance(): + curr_attn_backend = envs.VLLM_ATTENTION_BACKEND + supported_backends = [ + "FLASH_ATTN", # best supported backend + "FLEX_ATTENTION", + "FLASHINFER", + "FLASH_ATTN_MLA", + "FLASHINFER_MLA", + "TRITON_MLA", + # Not yet supported MLA backends + # "FLASHMLA", + ] + if curr_attn_backend not in supported_backends: + warning = ( + "Forcibly updating attention backend to" + f" {supported_backends[0]} for batch_invariant. " + f" Supported backends: {supported_backends}." + ) + logger.warning_once(warning) + os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0] + if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]: + warning = ( + "You are using a decode-invariant form of batch invariance. " + "This will not be invariant between prefill and decode." + ) + logger.warning_once(warning) + os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" + + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + # NCCL determinism settings + os.environ["NCCL_LAUNCH_MODE"] = "GROUP" + os.environ["NCCL_COLLNET_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["NCCL_P2P_NET_DISABLE"] = "1" + os.environ["NCCL_MIN_NCHANNELS"] = "1" + os.environ["NCCL_MAX_NCHANNELS"] = "1" + os.environ["NCCL_PROTO"] = "Simple" + os.environ["NCCL_ALGO"] = "allreduce:tree" + os.environ["NCCL_NTHREADS"] = "1" + os.environ["NCCL_SOCKET_NTHREADS"] = "1" + + def init_batch_invariance(): # this will hit all the csrc overrides as well - if vllm_kernel_override_batch_invariant(): - os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" + if vllm_is_batch_invariant(): + override_envs_for_invariance() enable_batch_invariant_mode() + + # Disable TF32 for batch invariance - it causes non-deterministic rounding + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index d65c87aba11c..b046a6d3919e 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 import warnings -from typing import Optional import torch from einops import rearrange @@ -32,7 +31,7 @@ def chunk_gated_delta_rule_fwd( scale: float, initial_state: torch.Tensor, output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: torch.LongTensor | None = None, ): g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. 
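Outside the Triton path, the formula in the rms_norm docstring above reduces to a few lines of PyTorch. A minimal CPU reference (a sketch for sanity-checking the batch-invariant wrapper, not a replacement for the kernel) looks like this:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, accumulated in float32 for stability.
    x_f32 = x.float()
    inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f32 * inv_rms * weight.float()).to(x.dtype)

x = torch.randn(4, 128, dtype=torch.float16)
w = torch.ones(128, dtype=torch.float16)
y = rms_norm_reference(x, w)
# On a CUDA device, rms_norm_batch_invariant(x.cuda(), w.cuda()) is expected to
# agree with y within half-precision tolerance.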
@@ -86,7 +85,7 @@ def forward( scale: float, initial_state: torch.Tensor, output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: torch.LongTensor | None = None, use_qk_l2norm_in_kernel: bool = False, ): if use_qk_l2norm_in_kernel: @@ -119,7 +118,7 @@ def chunk_gated_delta_rule( scale: float = None, initial_state: torch.Tensor = None, output_final_state: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: torch.LongTensor | None = None, head_first: bool = False, use_qk_l2norm_in_kernel: bool = False, ): diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py index 817962d9c946..1c14f84c2b89 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py +++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py @@ -7,7 +7,6 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional import torch @@ -257,12 +256,12 @@ def chunk_gated_delta_rule_fwd_h( k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, - g: Optional[torch.Tensor] = None, - initial_state: Optional[torch.Tensor] = None, + g: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, output_final_state: bool = False, chunk_size: int = 64, # SY: remove this argument and force chunk size 64? save_new_value: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, Hg, K, V = *k.shape, u.shape[-1] H = u.shape[-2] diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py index ae404a3615f6..4e8e04c1d48c 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_o.py +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -9,7 +9,6 @@ # ruff: noqa: E501 -from typing import Optional import torch @@ -144,9 +143,9 @@ def chunk_fwd_o( k: torch.Tensor, v: torch.Tensor, h: torch.Tensor, - g: Optional[torch.Tensor] = None, # cumsum of log decay - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, + g: torch.Tensor | None = None, # cumsum of log decay + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, ) -> torch.Tensor: B, T, Hg, K, V = *q.shape, v.shape[-1] diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py index 0da3f243901f..975e119af333 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py +++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py @@ -7,7 +7,6 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional import torch @@ -104,8 +103,8 @@ def chunk_scaled_dot_kkt_fwd_kernel( def chunk_scaled_dot_kkt_fwd( k: torch.Tensor, beta: torch.Tensor, - g_cumsum: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, + g_cumsum: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/fla/ops/cumsum.py b/vllm/model_executor/layers/fla/ops/cumsum.py index cfa2b3b48e70..99b41794796d 100644 --- a/vllm/model_executor/layers/fla/ops/cumsum.py +++ b/vllm/model_executor/layers/fla/ops/cumsum.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 import 
warnings -from typing import Optional import torch @@ -163,9 +162,9 @@ def chunk_local_cumsum_scalar( g: torch.Tensor, chunk_size: int, reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, + output_dtype: torch.dtype | None = torch.float, ) -> torch.Tensor: if head_first: B, H, T = g.shape @@ -200,9 +199,9 @@ def chunk_local_cumsum_vector( g: torch.Tensor, chunk_size: int, reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, + output_dtype: torch.dtype | None = torch.float, ) -> torch.Tensor: if head_first: B, H, T, S = g.shape @@ -248,9 +247,9 @@ def chunk_local_cumsum( g: torch.Tensor, chunk_size: int, reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, + output_dtype: torch.dtype | None = torch.float, **kwargs, ) -> torch.Tensor: if not head_first and g.shape[1] < g.shape[2]: diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index fa10bdb36caa..f3de1bfa2821 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -7,7 +7,6 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional import torch @@ -169,9 +168,9 @@ def fused_recurrent_gated_delta_rule_fwd( scale: float, initial_state: torch.Tensor, inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] @@ -248,9 +247,9 @@ def forward( scale: float, initial_state: torch.Tensor, inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ): o, final_state = fused_recurrent_gated_delta_rule_fwd( @@ -280,9 +279,9 @@ def fused_recurrent_gated_delta_rule( scale: float = None, initial_state: torch.Tensor = None, inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: r""" diff --git a/vllm/model_executor/layers/fla/ops/l2norm.py b/vllm/model_executor/layers/fla/ops/l2norm.py index 315dd904523b..4d7dbb510068 100644 --- a/vllm/model_executor/layers/fla/ops/l2norm.py +++ b/vllm/model_executor/layers/fla/ops/l2norm.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 
import os -from typing import Optional import torch @@ -90,7 +89,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): def l2norm_fwd( - x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None + x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None ): x_shape_og = x.shape x = x.view(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py index 655cdb3f30eb..307d0859c24e 100644 --- a/vllm/model_executor/layers/fla/ops/layernorm_guard.py +++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -13,7 +13,7 @@ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. -from typing import Optional +from functools import lru_cache import torch import torch.nn as nn @@ -21,6 +21,7 @@ from einops import rearrange from vllm.triton_utils import tl, triton +from vllm.utils import cdiv, next_power_of_2 from .utils import input_guard @@ -76,55 +77,103 @@ def layer_norm_fwd_kernel( stride_y_row, stride_z_row, M, # number of rows in X - N, # number of columns in X + N: tl.constexpr, # number of columns in X eps, # epsilon to avoid division by zero BLOCK_N: tl.constexpr, + ROWS_PER_BLOCK: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_Z: tl.constexpr, NORM_BEFORE_GATE: tl.constexpr, IS_RMS_NORM: tl.constexpr, ): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) + # Map the program id to the starting row of X and Y it should compute. + row_start = tl.program_id(0) * ROWS_PER_BLOCK group = tl.program_id(1) - X += row * stride_x_row + group * N - Y += row * stride_y_row + group * N - if HAS_Z: - Z += row * stride_z_row + group * N - if not IS_RMS_NORM: - Mean += group * M - Rstd += group * M - W += group * N - if HAS_BIAS: - B += group * N - # Compute mean and variance + + # Create 2D tile: [ROWS_PER_BLOCK, BLOCK_N] + rows = row_start + tl.arange(0, ROWS_PER_BLOCK) cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + + # Compute offsets for 2D tile + row_offsets = rows[:, None] * stride_x_row + col_offsets = cols[None, :] + group * N + + # Base pointers + X_base = X + row_offsets + col_offsets + Y_base = Y + rows[:, None] * stride_y_row + col_offsets + + # Create mask for valid rows and columns + row_mask = rows[:, None] < M + col_mask = cols[None, :] < N + mask = row_mask & col_mask + + # Load input data with 2D tile + x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: - z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + Z_base = Z + rows[:, None] * stride_z_row + col_offsets + z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) x *= z * tl.sigmoid(z) + + # Compute mean and variance per row (reduce along axis 1) if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N + mean = tl.sum(x, axis=1) / N # Shape: [ROWS_PER_BLOCK] + # Store mean for each row + mean_offsets = group * M + rows + mean_mask = rows < M + tl.store(Mean + mean_offsets, mean, mask=mean_mask) + # Broadcast mean back to 2D for subtraction + xbar = tl.where(mask, x - mean[:, None], 0.0) + var = tl.sum(xbar * xbar, axis=1) / N # Shape: [ROWS_PER_BLOCK] else: - xbar = tl.where(cols < N, x, 0.0) - var = 
tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) + xbar = tl.where(mask, x, 0.0) + var = tl.sum(xbar * xbar, axis=1) / N # Shape: [ROWS_PER_BLOCK] + mean = 0.0 # Placeholder for RMS norm + + rstd = tl.rsqrt(var + eps) # Shape: [ROWS_PER_BLOCK] + + # Store rstd for each row + rstd_offsets = group * M + rows + rstd_mask = rows < M + tl.store(Rstd + rstd_offsets, rstd, mask=rstd_mask) + + # Load weights and biases (broadcast across rows) + w_offsets = cols + group * N + w_mask = cols < N + w = tl.load(W + w_offsets, mask=w_mask, other=0.0).to(tl.float32) + if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w + b = tl.load(B + w_offsets, mask=w_mask, other=0.0).to(tl.float32) + + # Normalize and apply linear transformation + if not IS_RMS_NORM: + x_hat = (x - mean[:, None]) * rstd[:, None] + else: + x_hat = x * rstd[:, None] + + y = x_hat * w[None, :] + b[None, :] if HAS_BIAS else x_hat * w[None, :] + if HAS_Z and NORM_BEFORE_GATE: - z = tl.load(Z + cols, mask=mask).to(tl.float32) + Z_base = Z + rows[:, None] * stride_z_row + col_offsets + z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) y *= z * tl.sigmoid(z) + # Write output - tl.store(Y + cols, y, mask=mask) + tl.store(Y_base, y, mask=mask) + + +@lru_cache +def _get_sm_count(device: torch.device) -> int: + """Get and cache the SM count for a given device.""" + props = torch.cuda.get_device_properties(device) + return props.multi_processor_count + + +def calc_rows_per_block(M: int, device: torch.device) -> int: + sm_count = _get_sm_count(device) + rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count)) + rows_per_block = min(rows_per_block, 4) + return rows_per_block def layer_norm_fwd( @@ -171,7 +220,10 @@ def layer_norm_fwd( raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) - grid = (M, ngroups) + # Calculate rows per block based on SM count + rows_per_block = calc_rows_per_block(M, x.device) + # Update grid to use rows_per_block + grid = (cdiv(M, rows_per_block), ngroups) layer_norm_fwd_kernel[grid]( x, out, @@ -187,6 +239,7 @@ def layer_norm_fwd( group_size, eps, BLOCK_N=BLOCK_N, + ROWS_PER_BLOCK=rows_per_block, NORM_BEFORE_GATE=norm_before_gate, IS_RMS_NORM=is_rms_norm, num_warps=num_warps, @@ -270,10 +323,10 @@ def __init__( self, hidden_size, eps: float = 1e-5, - group_size: Optional[int] = None, + group_size: int | None = None, norm_before_gate: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): """If group_size is not None, we do GroupNorm with each group having group_size elements. group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). @@ -310,10 +363,10 @@ def __init__( self, hidden_size, eps: float = 1e-5, - group_size: Optional[int] = None, + group_size: int | None = None, norm_before_gate: bool = False, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ): """If group_size is not None, we do GroupNorm with each group having group_size elements. 
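The rows-per-block heuristic above is easier to read with numbers plugged in. A small sketch using local stand-ins for vllm.utils.cdiv and next_power_of_2 (assumed here to be plain ceiling division and power-of-two rounding), with a hypothetical 132-SM device:

import math

def cdiv(a: int, b: int) -> int:            # stand-in for vllm.utils.cdiv
    return -(-a // b)

def next_power_of_2(n: int) -> int:         # stand-in for vllm.utils.next_power_of_2
    return 1 if n <= 1 else 2 ** math.ceil(math.log2(n))

def calc_rows_per_block(M: int, sm_count: int) -> int:
    # Size programs so that roughly 2 * sm_count of them are launched,
    # while capping each program at 4 rows.
    return min(next_power_of_2(cdiv(M, 2 * sm_count)), 4)

# e.g. 4096 rows on a 132-SM device: cdiv(4096, 264) = 16, capped to 4 rows,
# so the launch grid has cdiv(4096, 4) = 1024 programs per group.
print(calc_rows_per_block(4096, 132))   # 4
print(calc_rows_per_block(100, 132))    # 1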
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). diff --git a/vllm/model_executor/layers/fla/ops/op.py b/vllm/model_executor/layers/fla/ops/op.py index ee2f4185a5df..a91975c8e567 100644 --- a/vllm/model_executor/layers/fla/ops/op.py +++ b/vllm/model_executor/layers/fla/ops/op.py @@ -11,29 +11,50 @@ from vllm.triton_utils import tl, tldevice, triton +from .utils import is_gather_supported + if os.environ.get("FLA_USE_FAST_OPS", "0") == "1": - div = tldevice.fast_dividef exp = tldevice.fast_expf log = tldevice.fast_logf log2 = tldevice.fast_log2f else: - - @triton.jit - def div_normal(x, y): - return x / y - - div = div_normal exp = tl.exp log = tl.log log2 = tl.log2 -if not hasattr(tl, "gather"): +if not is_gather_supported: @triton.jit def gather(src, index, axis, _builder=None): - # This is a fallback implementation when tl.gather is not supported - # In order to pass triton compiler, there is no actual gather operation - return src + """ + Gather operation that works when tl.gather is not supported. + This is a fallback implementation that returns None. + Just to make triton compiler happy. + """ + return None else: gather = tl.gather + +if hasattr(triton.language, "_experimental_make_tensor_descriptor"): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, "make_tensor_descriptor"): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. + """ + + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None diff --git a/vllm/model_executor/layers/fla/ops/solve_tril.py b/vllm/model_executor/layers/fla/ops/solve_tril.py index d30fea90aec3..da85aab19207 100644 --- a/vllm/model_executor/layers/fla/ops/solve_tril.py +++ b/vllm/model_executor/layers/fla/ops/solve_tril.py @@ -7,14 +7,22 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional + +import os import torch from vllm.triton_utils import tl, triton from .index import prepare_chunk_indices -from .utils import input_guard +from .op import make_tensor_descriptor +from .utils import input_guard, is_amd, is_tma_supported + +FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee") +ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"] +assert FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS, ( + f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}" +) @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @@ -29,13 +37,15 @@ @triton.jit(do_not_specialize=["T"]) def solve_tril_16x16_kernel( A, - Ad, + Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, + USE_TMA: tl.constexpr, IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H @@ -51,30 +61,43 @@ def solve_tril_16x16_kernel( T = eos - bos else: bos, eos = i_b * T, i_b * T + T + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] A = A + (bos * H + i_h) * BT - Ad = Ad + (bos * H + i_h) * 16 + Ai = Ai + (bos * H + i_h) * 16 offset = (i_t * 16) % BT - p_A = tl.make_block_ptr( - A, (T, BT), (H * BT, 1), 
(i_t * 16, offset), (16, 16), (1, 0) - ) - p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) - b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) - b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + if not USE_TMA: + p_A = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0) + ) + # [16, 16] + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16]) + b_A = desc.load([i_t * 16, offset]).to(tl.float32) + b_A = -tl.where(m_A, b_A, 0) - o_i = tl.arange(0, 16) - for i in range(1, min(16, T - i_t * 16)): + for i in range(2, min(16, T - i_t * 16)): + # [16] b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) - mask = o_i == i - b_A = tl.where(mask[:, None], b_a, b_A) - b_A += o_i[:, None] == o_i[None, :] - tl.store( - p_Ai, - b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) + b_A = tl.where((o_i == i)[:, None], b_a, b_A) + b_A += m_I + if not USE_TMA: + p_Ai = tl.make_block_ptr( + Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0) + ) + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @@ -89,14 +112,15 @@ def solve_tril_16x16_kernel( @triton.jit(do_not_specialize=["T"]) def merge_16x16_to_32x32_inverse_kernel( A, - Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, + USE_TMA: tl.constexpr, IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H @@ -113,51 +137,93 @@ def merge_16x16_to_32x32_inverse_kernel( else: bos, eos = i_b * T, i_b * T + T - A += (bos * H + i_h) * 32 - Ad += (bos * H + i_h) * 16 - Ai += (bos * H + i_h) * 32 + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT - p_A_21 = tl.make_block_ptr( - A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) - ) - p_Ad_11 = tl.make_block_ptr( - Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0) - ) - p_Ad_22 = tl.make_block_ptr( - Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) - ) - p_Ai_11 = tl.make_block_ptr( - Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0) - ) - p_Ai_22 = tl.make_block_ptr( - Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0) - ) - p_Ai_21 = tl.make_block_ptr( - Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) - ) + if not USE_TMA: + p_A_11 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0) + ) + p_A_22 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0) + ) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) - A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) - Ai_11 = tl.load(p_Ad_11, 
boundary_check=(0, 1)).to(tl.float32) - Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) - Ai_21 = -tl.dot( - tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee" - ) - tl.store( - p_Ai_11, - Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_22, - Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_21, - Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + + b_Ai_11 += m_I + b_Ai_22 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0) + ) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, ) + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0) + ) + p_Ai_21 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0) + ) + p_Ai_22 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0) + ) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store( + [i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.autotune( @@ -171,14 +237,15 @@ def merge_16x16_to_32x32_inverse_kernel( @triton.jit(do_not_specialize=["T"]) def merge_16x16_to_64x64_inverse_kernel( A, - Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, + USE_TMA: tl.constexpr, IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H @@ -195,278 +262,295 @@ def merge_16x16_to_64x64_inverse_kernel( else: bos, eos = i_b * T, i_b * T + T - A += (bos * H + i_h) * 64 - Ad += (bos * H + i_h) * 16 - Ai += (bos * H + i_h) * 64 + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT - p_A_21 = tl.make_block_ptr( - A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) - ) - p_A_32 = tl.make_block_ptr( - A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0) - ) - p_A_31 = tl.make_block_ptr( - A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 
0), (16, 16), (1, 0) - ) - p_A_43 = tl.make_block_ptr( - A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0) - ) - p_A_42 = tl.make_block_ptr( - A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0) - ) - p_A_41 = tl.make_block_ptr( - A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) - ) - p_Ad_11 = tl.make_block_ptr( - Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0) - ) - p_Ad_22 = tl.make_block_ptr( - Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) - ) - p_Ad_33 = tl.make_block_ptr( - Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0) - ) - p_Ad_44 = tl.make_block_ptr( - Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) - ) + if not USE_TMA: + p_A_11 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0) + ) + p_A_22 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0) + ) + p_A_33 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0) + ) + p_A_44 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0) + ) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32) + b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32) + b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32) - A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) - A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) - A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) - A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) - A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) - A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + b_Ai_33 = -tl.where(m_A, b_Ai_33, 0) + b_Ai_44 = -tl.where(m_A, b_Ai_44, 0) - Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) - Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) - Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) - Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + for i in range(32 + 2, min(48, T - i_t * BT)): + b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32) + b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0) + b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33) + for i in range(48 + 2, min(64, T - i_t * BT)): + b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48) + b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0) + b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44) + b_Ai_11 += m_I + b_Ai_22 += m_I + b_Ai_33 += m_I + b_Ai_44 += m_I - Ai_21 = -tl.dot( - tl.dot(Ai_22, A_21, 
input_precision="ieee"), Ai_11, input_precision="ieee" - ) - Ai_32 = -tl.dot( - tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee" - ) - Ai_43 = -tl.dot( - tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee" - ) + if not USE_TMA: + p_A_21 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0) + ) + p_A_31 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0) + ) + p_A_32 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0) + ) + p_A_41 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0) + ) + p_A_42 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0) + ) + p_A_43 = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0) + ) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32) + b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32) + b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32) + b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32) + b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32) - Ai_31 = -tl.dot( - Ai_33, - tl.dot(A_31, Ai_11, input_precision="ieee") - + tl.dot(A_32, Ai_21, input_precision="ieee"), - input_precision="ieee", + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, ) - Ai_42 = -tl.dot( - Ai_44, - tl.dot(A_42, Ai_22, input_precision="ieee") - + tl.dot(A_43, Ai_32, input_precision="ieee"), - input_precision="ieee", + b_Ai_32 = -tl.dot( + tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION), + b_Ai_22, + input_precision=DOT_PRECISION, ) - Ai_41 = -tl.dot( - Ai_44, - tl.dot(A_41, Ai_11, input_precision="ieee") - + tl.dot(A_42, Ai_21, input_precision="ieee") - + tl.dot(A_43, Ai_31, input_precision="ieee"), - input_precision="ieee", + b_Ai_43 = -tl.dot( + tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION), + b_Ai_33, + input_precision=DOT_PRECISION, ) - p_Ai_11 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0) - ) - p_Ai_22 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0) - ) - p_Ai_33 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0) - ) - p_Ai_44 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0) - ) - p_Ai_21 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) - ) - p_Ai_31 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0) - ) - p_Ai_32 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0) - ) - p_Ai_41 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) - ) - p_Ai_42 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0) - ) - p_Ai_43 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0) - ) - tl.store( - p_Ai_11, - Ai_11.to(p_Ai_11.dtype.element_ty, 
fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_22, - Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_33, - Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_44, - Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_21, - Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_31, - Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_32, - Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_41, - Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_42, - Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_43, - Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), + b_Ai_31 = -tl.dot( + b_Ai_33, + tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_42 = -tl.dot( + b_Ai_44, + tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_41 = -tl.dot( + b_Ai_44, + tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, ) - fill_zeros = tl.zeros((16, 16), dtype=tl.float32) - p_Ai_12 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0) - ) - p_Ai_13 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0) - ) - p_Ai_14 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0) - ) - p_Ai_23 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0) - ) - p_Ai_24 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0) - ) - p_Ai_34 = tl.make_block_ptr( - Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0) - ) - tl.store( - p_Ai_12, - fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_13, - fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_14, - fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_23, - fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_24, - fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) - tl.store( - p_Ai_34, - fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1), - ) + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0) + ) + p_Ai_22 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0) + ) + p_Ai_33 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0) + ) + p_Ai_44 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0) + ) + p_Ai_21 
= tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0) + ) + p_Ai_31 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0) + ) + p_Ai_32 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0) + ) + p_Ai_41 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0) + ) + p_Ai_42 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0) + ) + p_Ai_43 = tl.make_block_ptr( + Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0) + ) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_42, + b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store( + [i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) + desc_o.store( + [i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne") + ) @input_guard def solve_tril( A: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, output_dtype: torch.dtype = torch.float, ) -> torch.Tensor: """ - Compute the inverse of the lower triangular matrix + Compute the inverse of the matrix I + A A should be strictly lower triangular, i.e., A.triu() == 0. Args: A (torch.Tensor): - [B, T, H, K] + [B, T, H, BT], where BT should only be 16, 32, or 64. cu_seqlens (torch.Tensor): - The cumulative sequence lengths of the input tensor. - Default: None. + The cumulative sequence lengths of the input tensor. Default: `None`. output_dtype (torch.dtype): - The dtype of the output tensor. Default: `torch.float` + The dtype of the output tensor. Default: `torch.float`. 
+ If `None`, the output dtype will be the same as the input dtype. Returns: (I + A)^-1 with the same shape as A """ assert A.shape[-1] in [16, 32, 64] + output_dtype = A.dtype if output_dtype is None else output_dtype B, T, H, BT = A.shape - Ad = torch.empty( - B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype - ) - - chunk_indices = ( - prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None - ) - NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) - solve_tril_16x16_kernel[NT, B * H]( - A=A, - Ad=Ad, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - BT=BT, - ) - if BT == 16: - return Ad - - Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) - merge_fn = ( - merge_16x16_to_32x32_inverse_kernel - if BT == 32 - else merge_16x16_to_64x64_inverse_kernel - ) chunk_indices = ( prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None ) NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + + Ai = torch.zeros_like(A, dtype=output_dtype) + if BT == 16: + merge_fn = solve_tril_16x16_kernel + elif BT == 32: + merge_fn = merge_16x16_to_32x32_inverse_kernel + elif BT == 64: + merge_fn = merge_16x16_to_64x64_inverse_kernel + merge_fn[NT, B * H]( A=A, - Ad=Ad, Ai=Ai, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, H=H, BT=BT, + USE_TMA=is_tma_supported, + DOT_PRECISION=FLA_TRIL_PRECISION, ) return Ai diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py index 07124f33f1e6..3a503981a873 100644 --- a/vllm/model_executor/layers/fla/ops/utils.py +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -11,8 +11,9 @@ import functools import logging import os +from collections.abc import Callable from enum import Enum -from typing import Any, Callable, Literal, Optional +from typing import Any, Literal import torch @@ -43,8 +44,8 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor] A wrapped version of the input function with single-entry caching. 
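As a sanity reference for the blockwise inverse construction in solve_tril above (the merge_16x16_to_*_inverse_kernel paths assemble (I + A)^-1 from the 16x16 diagonal-block inverses), here is a minimal eager-PyTorch sketch of the same recurrence. The helper names reference_solve_tril and merge_inverse_2x2 are hypothetical and introduced only for illustration; the real kernels additionally handle the [B, T, H, BT] layout, cu_seqlens, the TMA store path, and reduced-precision dots.

import torch

def reference_solve_tril(A_block: torch.Tensor) -> torch.Tensor:
    # A_block: (BT, BT), strictly lower triangular (A_block.triu() == 0).
    # Returns (I + A_block)^-1, which is what solve_tril is documented to
    # produce for each BT x BT chunk along the T dimension.
    eye = torch.eye(A_block.shape[-1], dtype=A_block.dtype, device=A_block.device)
    return torch.linalg.inv(eye + A_block)

def merge_inverse_2x2(Ai_11, Ai_22, A_21):
    # For M = [[I + A_11, 0], [A_21, I + A_22]], the off-diagonal block of
    # M^-1 is -Ai_22 @ A_21 @ Ai_11 -- the same recurrence the merge kernels
    # apply (pairwise for 32x32, and with three-term sums such as b_Ai_31
    # for the 64x64 case).
    Ai_21 = -(Ai_22 @ A_21 @ Ai_11)
    top = torch.cat([Ai_11, torch.zeros_like(Ai_11)], dim=-1)
    bot = torch.cat([Ai_21, Ai_22], dim=-1)
    return torch.cat([top, bot], dim=-2)

BT = 32
A = (0.1 * torch.randn(BT, BT, dtype=torch.float64)).tril(-1)  # strictly lower triangular
merged = merge_inverse_2x2(
    reference_solve_tril(A[:16, :16]),
    reference_solve_tril(A[16:, 16:]),
    A[16:, :16],
)
assert torch.allclose(merged, reference_solve_tril(A))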
""" - cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] - cache_size = 4 + cache_entries: tuple[tuple | None, dict | None, Any] = [] + cache_size = 8 @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -149,6 +150,11 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: or torch.cuda.get_device_capability()[0] >= 9 ) use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" +is_gather_supported = hasattr(triton.language, "gather") +is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") +) def get_all_max_shared_mem(): diff --git a/vllm/model_executor/layers/fla/ops/wy_fast.py b/vllm/model_executor/layers/fla/ops/wy_fast.py index b628a90e843f..a66ec1d60d66 100644 --- a/vllm/model_executor/layers/fla/ops/wy_fast.py +++ b/vllm/model_executor/layers/fla/ops/wy_fast.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional import torch @@ -123,7 +122,7 @@ def recompute_w_u_fwd( beta: torch.Tensor, g_cumsum: torch.Tensor, A: torch.Tensor, - cu_seqlens: Optional[torch.LongTensor], + cu_seqlens: torch.LongTensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, Hg, K, V = *k.shape, v.shape[-1] H = v.shape[-2] diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 56ffaf861ac7..cb31045971bd 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import Any, Optional +from typing import Any from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( @@ -15,10 +15,11 @@ FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, ) +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.triton_utils import HAS_TRITON -_config: Optional[dict[str, Any]] = None +_config: dict[str, Any] | None = None @contextmanager @@ -30,7 +31,7 @@ def override_config(config): _config = old_config -def get_config() -> Optional[dict[str, Any]]: +def get_config() -> dict[str, Any] | None: return _config @@ -42,6 +43,7 @@ def get_config() -> Optional[dict[str, Any]]: "FusedMoEPermuteExpertsUnpermute", "FusedMoEActivationFormat", "FusedMoEPrepareAndFinalize", + "SharedFusedMoE", "activation_without_mul", "override_config", "get_config", @@ -49,7 +51,6 @@ def get_config() -> Optional[dict[str, Any]]: if HAS_TRITON: # import to register the custom ops - import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index f30ebec76c67..095ec966ea7e 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,21 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from math import log2 -from typing import Optional import 
torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import ( + fp8_m_grouped_gemm_nt_masked, + get_mk_alignment_for_contiguous_layout, + is_deep_gemm_e8m0_used, +) logger = init_logger(__name__) @@ -94,7 +95,7 @@ def _silu_mul_fp8_quant_deep_gemm( tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) -def silu_mul_fp8_quant_deep_gemm_cuda( +def persistent_masked_m_silu_mul_quant( y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert num_parallel_tokens=16, @@ -103,9 +104,41 @@ def silu_mul_fp8_quant_deep_gemm_cuda( """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. + We launch a fixed grid of threads to accommodate CUDA graphs. Let `P2` + be a parallelization factor for persistent_masked_m_silu_mul_quant over the + hidden dimension. + + Let `expert_offsets = [0] + [num_tokens.cumsum()]` and + `total_tokens = expert_offsets[-1]`. + persistent_masked_m_silu_mul_quant launches `total_tokens x P2` number of + thread blocks. Each thread block contains `NUM_WARPS` warps. + + Every thread block needs to find it's corresponding expert by warp-parallel scanning + over the `expert_offsets` array. + + The i-th warp in the first thread block processes + `[i * warp_chunk_size, (i + 1) * warp_chunk_size]` groups + sequentially, where `warp_chunk_size = ((H / GROUP_SIZE) / P2) / NUM_WARPS`, + pipelining loads and computes. + + The shared memory layout for 4 warps with a 2-stage pipeline for SiLU V2 + can is visualized like so: + + stage0 stage1 + ┌─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┐ + │gate0│up0│gate1│up1│gate2│up2│gate3│up3│gate0│up0│gate1│up1│gate2│up2│gate3│up3│ + └─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┘ + + with the main difference between V1 and V2 being the global load + stride between warps, and between half-warps. Regarding the latter stride, + we assign the first half warp of every warp for `gate` loads and the second + half-warp to `up` loads. + Returns `(y_q, y_s)` where * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + Let NUM_WARPS be the number of warps in a single thread block and + `GROUP_SIZE = 128` be the size of the quantization group. """ assert y.ndim == 3, "y must be (E, T, 2*H)" E, T, H2 = y.shape @@ -133,30 +166,15 @@ def silu_mul_fp8_quant_deep_gemm_cuda( use_ue8m0 = is_deep_gemm_e8m0_used() - if E <= 16: - max_empirical_parallelism = 64 - elif E <= 32: - max_empirical_parallelism = 16 - else: - max_empirical_parallelism = 4 - - # We never want to launch more than Tx number of threads - # This computes the clip. 
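To make the quantization contract described in the docstring above concrete, here is a small eager-mode sketch under simplifying assumptions: GROUP_SIZE = 128, plain amax/fp8_max scales rather than the ue8m0 path, every token treated as valid (tokens_per_expert ignored), and contiguous outputs instead of the (T*G, 1, T)-strided y_s. The name reference_silu_mul_quant is illustrative only and is not part of this module.

import torch
import torch.nn.functional as F

def reference_silu_mul_quant(y: torch.Tensor, group_size: int = 128):
    # y: (E, T, 2*H) -> y_q: (E, T, H) in FP8, y_s: (E, T, H // group_size)
    E, T, H2 = y.shape
    H = H2 // 2
    # First half of the hidden dim is silu-activated and multiplied by the
    # second half, matching the docstring above.
    act = F.silu(y[..., :H].float()) * y[..., H:].float()

    G = H // group_size
    grouped = act.view(E, T, G, group_size)
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    # One scale per (expert, token, group); the fused kernel additionally
    # writes y_s with strides (T*G, 1, T) and skips tokens beyond
    # tokens_per_expert[e].
    y_s = grouped.abs().amax(dim=-1).clamp(min=1e-10) / fp8_max
    y_q = (grouped / y_s.unsqueeze(-1)).clamp(-fp8_max, fp8_max).to(fp8)
    return y_q.view(E, T, H), y_s

E, T, H = 4, 8, 256
y = torch.randn(E, T, 2 * H, dtype=torch.bfloat16)
y_q, y_s = reference_silu_mul_quant(y)
assert y_q.shape == (E, T, H) and y_s.shape == (E, T, H // 128)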
- num_parallel_tokens = max( - 1, min(max_empirical_parallelism, 2 ** int(log2(min(num_parallel_tokens, T)))) - ) cuda_arch = current_platform.get_device_capability( device_id=y.device.index ).to_int() if cuda_arch >= 80: - torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda( - y, tokens_per_expert, y_q, y_s, group_size, use_ue8m0, num_parallel_tokens + torch.ops._C.persistent_masked_m_silu_mul_quant( + y, tokens_per_expert, y_q, y_s, use_ue8m0 ) else: - # Default to triton if not on cuda or if arch is too old - y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) - stride_cnt_e = tokens_per_expert.stride()[0] # Static grid over experts and H-groups. @@ -166,16 +184,6 @@ def silu_mul_fp8_quant_deep_gemm_cuda( stride_i_e, stride_i_t, stride_i_h = y.stride() stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() - # desired scale strides (elements): (T*G, 1, T) - stride_ys_e = T * G - stride_ys_t = 1 - stride_ys_g = T - y_s = torch.empty_strided( - (E, T, G), - (stride_ys_e, stride_ys_t, stride_ys_g), - dtype=torch.float32, - device=y.device, - ) f_info = torch.finfo(fp8_dtype) fp8_max = f_info.max fp8_min = f_info.min @@ -222,7 +230,7 @@ def __init__( quant_config: Quantization configuration """ super().__init__(quant_config) - assert self.block_shape == deep_gemm_block_shape() + assert self.block_shape == get_mk_alignment_for_contiguous_layout() self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -247,29 +255,24 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # FIXME (varun): We should be able to dispatch only from the leader # DP ranks in the case of TP > 1. At the moment, all the Ranks # end up sending their tokens. This needs to be fixed. 
num_dispatchers = self.num_dispatchers num_experts = local_num_experts - max_num_tokens = ( - a.size(0) if self.max_num_tokens is None else self.max_num_tokens - ) + max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) output = (num_experts, max_num_tokens * num_dispatchers, K) - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def apply( self, @@ -281,12 +284,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): assert expert_tokens_meta is not None @@ -300,7 +303,7 @@ def apply( assert w2.size(1) == K - E, max_num_tokens, N, K, top_k_num = self.moe_problem_size( + E, max_num_tokens, N, K, _ = self.moe_problem_size( hidden_states, w1, w2, topk_ids ) @@ -318,7 +321,7 @@ def apply( expected_m, ) - a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda( + a2q, a2q_scale = persistent_masked_m_silu_mul_quant( workspace1, expert_num_tokens ) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index d268f70477f4..e69e9fd307ae 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -9,8 +8,8 @@ BatchedDeepGemmExperts, ) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts +from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -32,7 +31,7 @@ def __init__( self.allow_deep_gemm = ( allow_deep_gemm and self.quant_config.use_fp8_w8a8 - and self.block_shape == deep_gemm_block_shape() + and self.block_shape == get_mk_alignment_for_contiguous_layout() ) self.batched_deep_gemm_experts = ( @@ -99,26 +98,25 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: assert bte_war is not None return bte_war + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_metadata: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. 
if expert maps are set. if self.allow_deep_gemm: assert self.batched_deep_gemm_experts is not None return self.batched_deep_gemm_experts.workspace_shapes( - a, - aq, M, N, K, @@ -130,8 +128,6 @@ def workspace_shapes( else: assert self.batched_triton_experts is not None return self.batched_triton_experts.workspace_shapes( - a, - aq, M, N, K, @@ -151,12 +147,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): experts = ( diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 5780c969d273..200212dfb42b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -14,8 +14,9 @@ OCP_MX_Scheme, ) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape -from vllm.utils import cdiv, has_triton_kernels +from vllm.utils import cdiv from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.import_utils import has_triton_kernels logger = init_logger(__name__) @@ -34,8 +35,8 @@ def _get_config_dtype_str( use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - ocp_mx_scheme: Optional[str] = None, -) -> Optional[str]: + ocp_mx_scheme: str | None = None, +) -> str | None: """ Return a string used to construct the filename that contains the tuning info for a particular quantization scheme. See @@ -60,16 +61,16 @@ def _get_config_dtype_str( def _quant_flags_to_group_shape( - quant_dtype: Union[torch.dtype, str, None], + quant_dtype: torch.dtype | str | None, per_act_token_quant: bool, per_out_ch_quant: bool, - block_shape: Optional[list[int]], -) -> tuple[Optional[GroupShape], Optional[GroupShape]]: + block_shape: list[int] | None, +) -> tuple[GroupShape | None, GroupShape | None]: """ Convert MoE quantization flags into more generic GroupShapes. """ - a_shape: Optional[GroupShape] - w_shape: Optional[GroupShape] + a_shape: GroupShape | None + w_shape: GroupShape | None if block_shape is not None: assert not per_act_token_quant assert not per_out_ch_quant @@ -100,7 +101,7 @@ class FusedMoEQuantDesc: # The quantized type of this parameters. None means unquantized or # already quantized. # TODO (bnell): use scalar_type instead of Union. - dtype: Union[torch.dtype, str, None] = None + dtype: torch.dtype | str | None = None # A field that describes the quantization group shape, from quant_utils.py. # * (-1, -1) for per-tensor quantization @@ -109,7 +110,7 @@ class FusedMoEQuantDesc: # * (128, 128) for 128x128 deepseek style block quantization # * (1, 128) for deepseek style activation quantization # (i.e. per-token-per-group) - shape: Optional[GroupShape] = None + shape: GroupShape | None = None # Quantization scales. # TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc? @@ -117,13 +118,13 @@ class FusedMoEQuantDesc: # Quantization alphas or gscales, used for nvfp4 types. 
# TODO(bnell): put some of these in subclasses - alpha_or_gscale: Optional[torch.Tensor] = None + alpha_or_gscale: torch.Tensor | None = None # Zero points for int4/int8 types - zp: Optional[torch.Tensor] = None + zp: torch.Tensor | None = None # Biases for GPT triton MoE - bias: Optional[torch.Tensor] = None + bias: torch.Tensor | None = None # TODO(bnell): have subclasses for specific moe methods? @@ -179,7 +180,7 @@ def __post_init__(self): # @property - def quant_dtype(self) -> Union[torch.dtype, str, None]: + def quant_dtype(self) -> torch.dtype | str | None: return self._a1.dtype @property @@ -203,7 +204,7 @@ def is_per_tensor(self) -> bool: return self._a1.shape == GroupShape.PER_TENSOR @property - def block_shape(self) -> Optional[list[int]]: + def block_shape(self) -> list[int] | None: if ( self._a1.shape is not None and self._a1.shape != GroupShape.PER_TENSOR @@ -218,34 +219,34 @@ def is_block_quantized(self) -> bool: return self.block_shape is not None @property - def a1_scale(self) -> Optional[torch.Tensor]: + def a1_scale(self) -> torch.Tensor | None: assert self._a1.scale is None or isinstance(self._a1.scale, torch.Tensor) return self._a1.scale @property - def a1_gscale(self) -> Optional[torch.Tensor]: + def a1_gscale(self) -> torch.Tensor | None: return self._a1.alpha_or_gscale @property - def a2_scale(self) -> Optional[torch.Tensor]: + def a2_scale(self) -> torch.Tensor | None: assert self._a2.scale is None or isinstance(self._a2.scale, torch.Tensor) return self._a2.scale @property - def a2_gscale(self) -> Optional[torch.Tensor]: + def a2_gscale(self) -> torch.Tensor | None: return self._a2.alpha_or_gscale @property - def w1_scale(self) -> Optional[torch.Tensor]: + def w1_scale(self) -> torch.Tensor | None: assert self._w1.scale is None or isinstance(self._w1.scale, torch.Tensor) return self._w1.scale @property - def w1_zp(self) -> Optional[torch.Tensor]: + def w1_zp(self) -> torch.Tensor | None: return self._w1.zp @property - def w1_bias(self) -> Optional[torch.Tensor]: + def w1_bias(self) -> torch.Tensor | None: return self._w1.bias @property @@ -254,20 +255,20 @@ def w1_precision(self) -> Optional["PrecisionConfig"]: return self._w1.scale @property - def g1_alphas(self) -> Optional[torch.Tensor]: + def g1_alphas(self) -> torch.Tensor | None: return self._w1.alpha_or_gscale @property - def w2_scale(self) -> Optional[torch.Tensor]: + def w2_scale(self) -> torch.Tensor | None: assert self._w2.scale is None or isinstance(self._w2.scale, torch.Tensor) return self._w2.scale @property - def w2_zp(self) -> Optional[torch.Tensor]: + def w2_zp(self) -> torch.Tensor | None: return self._w2.zp @property - def w2_bias(self) -> Optional[torch.Tensor]: + def w2_bias(self) -> torch.Tensor | None: return self._w2.bias @property @@ -276,7 +277,7 @@ def w2_precision(self) -> Optional["PrecisionConfig"]: return self._w2.scale @property - def g2_alphas(self) -> Optional[torch.Tensor]: + def g2_alphas(self) -> torch.Tensor | None: return self._w2.alpha_or_gscale @property @@ -296,7 +297,7 @@ def use_int4_w4a16(self) -> bool: return self._a1.dtype is None and self._w1.dtype == "int4" @property - def ocp_mx_scheme(self) -> Union[str, None]: + def ocp_mx_scheme(self) -> str | None: if not hasattr(self, "_ocp_mx_scheme"): if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or ( self._w1.dtype is not None and not isinstance(self._w1.dtype, str) @@ -322,7 +323,7 @@ def use_mxfp4_w4a16(self) -> bool: def use_nvfp4_w4a4(self) -> bool: return self.quant_dtype == "nvfp4" - def 
config_name(self, dtype: torch.dtype) -> Optional[str]: + def config_name(self, dtype: torch.dtype) -> str | None: """ Return a string used to construct the filename that contains the tuning info for a particular quantization scheme. See @@ -340,7 +341,7 @@ def scale_shape( self, max_tokens: int, hidden_dim: int, - ) -> Optional[tuple[int, int]]: + ) -> tuple[int, int] | None: """ Construct the proper activation scale shape for this config. @@ -363,7 +364,7 @@ def batched_scale_shape( num_experts: int, max_tokens: int, hidden_dim: int, - ) -> Optional[tuple[int, int, int]]: + ) -> tuple[int, int, int] | None: """ Construct the proper activation batched scale shape for this config, e.g. (num experts, *scale_shape). @@ -377,23 +378,23 @@ def batched_scale_shape( @staticmethod def make( - quant_dtype: Union[torch.dtype, str, None] = None, + quant_dtype: torch.dtype | str | None = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - g1_alphas: Optional[torch.Tensor] = None, - g2_alphas: Optional[torch.Tensor] = None, - a1_gscale: Optional[torch.Tensor] = None, - a2_gscale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - weight_dtype: Union[torch.dtype, str, None] = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + g1_alphas: torch.Tensor | None = None, + g2_alphas: torch.Tensor | None = None, + a1_gscale: torch.Tensor | None = None, + a2_gscale: torch.Tensor | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + weight_dtype: torch.dtype | str | None = None, ) -> "FusedMoEQuantConfig": """ General builder function for a FusedMoEQuantConfig. @@ -457,11 +458,11 @@ def make( def fp8_w8a8_moe_quant_config( w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for fp8 activations and fp8 weights. @@ -481,8 +482,8 @@ def fp8_w8a8_moe_quant_config( def int8_w8a8_moe_quant_config( w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + a1_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, per_act_token_quant: bool = False, ) -> FusedMoEQuantConfig: """ @@ -503,8 +504,8 @@ def int8_w8a8_moe_quant_config( def mxfp4_w4a16_moe_quant_config( w1_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"], - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for unquantized activations and mxfp4 weights. 
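For orientation, a hedged usage sketch of the builder API above, assuming this patch is applied. The tensor shapes are placeholders chosen only to resemble DeepSeek-style 128x128 weight-block scales; real scales come from the quantized checkpoint.

import torch
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config

E, N, K = 8, 1024, 512                            # experts, intermediate, hidden
w1_scale = torch.ones(E, 2 * N // 128, K // 128)  # placeholder block scales
w2_scale = torch.ones(E, K // 128, N // 128)

quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    block_shape=[128, 128],  # DeepSeek-style 128x128 weight blocks with
                             # per-token-per-group (1, 128) activation quant
)
assert quant_config.use_fp8_w8a8        # fp8 activations and fp8 weights
assert quant_config.is_block_quantized  # expected, since block_shape was given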
@@ -517,16 +518,36 @@ def mxfp4_w4a16_moe_quant_config( ) +def mxfp4_mxfp8_moe_quant_config( + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + block_shape: list[int] | None = None, +) -> FusedMoEQuantConfig: + """ + Construct a quant config for mxfp8 activations and mxfp4 weights. + """ + return FusedMoEQuantConfig( + _a1=FusedMoEQuantDesc("mxfp8"), + _a2=FusedMoEQuantDesc("mxfp8"), + _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias), + _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias), + ) + + def ocp_mx_moe_quant_config( quant_dtype: str, w1_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"], - weight_dtype: Optional[str] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, + weight_dtype: str | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + block_shape: list[int] | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and mxfp4 weights. @@ -575,9 +596,9 @@ def nvfp4_moe_quant_config( def int4_w4a16_moe_quant_config( w1_scale: torch.Tensor, w2_scale: torch.Tensor, - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - block_shape: Optional[list[int]] = None, + w1_zp: torch.Tensor | None, + w2_zp: torch.Tensor | None, + block_shape: list[int] | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for 16-bit float activations and int4 weights. @@ -595,9 +616,9 @@ def int8_w8a16_moe_quant_config( w1_scale: torch.Tensor, w2_scale: torch.Tensor, - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - block_shape: Optional[list[int]] = None, + w1_zp: torch.Tensor | None, + w2_zp: torch.Tensor | None, + block_shape: list[int] | None = None, ) -> FusedMoEQuantConfig: """ Construct a quant config for 16-bit float activations and int8 weights. @@ -613,8 +634,8 @@ def biased_moe_quant_config( - w1_bias: Optional[torch.Tensor], - w2_bias: Optional[torch.Tensor], + w1_bias: torch.Tensor | None, + w2_bias: torch.Tensor | None, ) -> FusedMoEQuantConfig: """ Construct a quant config for unquantized activations with biases. 
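A small, hedged comparison of the two mxfp4 builders above, assuming this patch is applied and that the w4a16 builder leaves activations unquantized as its docstring states. The empty scale tensors are placeholders; real scales come from the quantization method.

import torch
from vllm.model_executor.layers.fused_moe.config import (
    mxfp4_mxfp8_moe_quant_config,
    mxfp4_w4a16_moe_quant_config,
)

w1_scale = torch.empty(0)  # placeholder weight scales
w2_scale = torch.empty(0)

a16_cfg = mxfp4_w4a16_moe_quant_config(w1_scale=w1_scale, w2_scale=w2_scale)
mx8_cfg = mxfp4_mxfp8_moe_quant_config(w1_scale=w1_scale, w2_scale=w2_scale)

# Both keep mxfp4 weights; they differ only on the activation side:
assert a16_cfg.quant_dtype is None      # activations left in 16-bit float
assert mx8_cfg.quant_dtype == "mxfp8"   # activations quantized to mxfp8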
@@ -641,6 +662,7 @@ class FusedMoEParallelConfig: ep_rank: int use_ep: bool # whether to use EP or not + all2all_backend: str # all2all backend for MoE communication @property def use_all2all_kernels(self): @@ -648,21 +670,29 @@ def use_all2all_kernels(self): @property def use_pplx_kernels(self): - return self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "pplx" + return self.use_all2all_kernels and self.all2all_backend == "pplx" @property def use_deepep_ht_kernels(self): return ( self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + and self.all2all_backend == "deepep_high_throughput" ) @property def use_deepep_ll_kernels(self): - return ( - self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" - ) + return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" + + @staticmethod + def flatten_tp_across_dp( + tp_size: int, dp_size: int, dp_rank: int + ) -> tuple[int, int]: + tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size * tp_size devices. Update tp_size + # and tp_rank so we shard across all devices. + flatten_tp_size = dp_size * tp_size + flatten_tp_rank = dp_rank * tp_size + tp_rank + return flatten_tp_size, flatten_tp_rank @staticmethod def make( @@ -739,19 +769,13 @@ def make( between the 4 devices. """ - def flatten_tp_across_dp(dp_rank: int): - tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() - # There are actually dp_size_ * tp_size_ devices. Update tp_size - # and tp_rank so we shard across all devices. - tp_size = dp_size_ * tp_size_ - tp_rank = dp_rank * tp_size_ + tp_rank - return tp_size, tp_rank - use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 - tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp( + tp_size_, dp_size_, dp_rank + ) if not use_ep: return FusedMoEParallelConfig( @@ -762,6 +786,7 @@ def flatten_tp_across_dp(dp_rank: int): ep_size=1, ep_rank=0, use_ep=False, + all2all_backend=vllm_parallel_config.all2all_backend, ) # DP + EP / TP + EP / DP + TP + EP assert use_ep @@ -777,6 +802,7 @@ def flatten_tp_across_dp(dp_rank: int): ep_size=ep_size, ep_rank=ep_rank, use_ep=True, + all2all_backend=vllm_parallel_config.all2all_backend, ) @@ -797,6 +823,8 @@ class FusedMoEConfig: has_bias: bool = False + is_act_and_mul: bool = True + def __post_init__(self): if self.dp_size > 1: logger.debug_once( diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..86b49127f9bf --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + 
"16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200.json new file mode 100644 index 000000000000..ea1ce9ad2cdc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json new file mode 100644 index 000000000000..d613de3a754f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..592b60c5acea --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..2a626ac47b8d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200.json new file mode 100644 index 000000000000..371e87f94682 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + 
}, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..8b94452197b0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200.json new file mode 100644 index 000000000000..48f19df24cc9 --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..eb4d11c6be2b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + 
"num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..c2f79b966abb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + 
"16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI355_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI355_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..c1ca10063189 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI355_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=1408,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=1408,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..8ed3ad352717 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=1408,device_name=NVIDIA_B200.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=40,N=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=40,N=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..7ffa2ac89487 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=40,N=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..b0bf1bf51785 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, 
+ "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..9952f8083479 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=NVIDIA_B200.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H100_PCIe,dtype=fp8_w8a8,block_shape=[128,128].json 
b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H100_PCIe,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..2c897dbce17e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H100_PCIe,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json new file mode 100644 index 000000000000..fd675df5d564 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..e410671b6fd4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index 3592a88b0ef2..552d9e9cf88f 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from 
torch.nn import functional as F @@ -33,7 +33,7 @@ def grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -88,12 +88,12 @@ def select_experts( top_k: int, use_grouped_topk: bool, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if use_grouped_topk: assert topk_group is not None @@ -147,14 +147,14 @@ def __call__( top_k: int, router_logits: torch.Tensor, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: @@ -189,14 +189,14 @@ def __call__( top_k: int, router_logits: torch.Tensor, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: @@ -247,14 +247,14 @@ def __call__( top_k: int, router_logits: torch.Tensor, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d3fed9332958..6753a19250b3 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project """CUTLASS based Fused MoE kernels.""" -from typing import Callable, Optional +from collections.abc import Callable import torch @@ -35,23 +35,23 @@ def run_cutlass_moe_fp8( topk_ids: torch.Tensor, activation_callable: Callable, global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + w1_scale: torch.Tensor | None, + w2_scale: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, c_strides2: torch.Tensor, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_num_tokens: torch.Tensor | None, out_dtype: torch.dtype, per_act_token: bool, per_out_ch: bool, use_batched_format: bool, - topk_weights: Optional[torch.Tensor], + topk_weights: torch.Tensor | None, ): a1q = hidden_states @@ -249,7 +249,7 @@ def run_cutlass_moe_fp8( class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - out_dtype: Optional[torch.dtype], + out_dtype: torch.dtype | None, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, @@ -278,12 +278,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE" @@ -331,7 +331,7 @@ def apply( class CutlassExpertsFp8(CutlassExpertsFp8Base): def __init__( self, - out_dtype: Optional[torch.dtype], + out_dtype: torch.dtype | None, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, @@ -366,27 +366,23 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # topk weights and reduction are fused in moe_unpermute cuda kernel return TopKWeightAndReduceNoOP() + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M * topk, max(N, K)) workspace2 = (M * topk, max(N // 2, K)) output = (M, K) - return ( - workspace1, - workspace2, - output, - self.out_dtype if self.out_dtype is not None else a.dtype, - ) + return (workspace1, workspace2, output) class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): @@ -394,7 +390,7 @@ def __init__( self, max_experts_per_worker: int, num_dispatchers: int, - out_dtype: Optional[torch.dtype], + out_dtype: torch.dtype | None, ab_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides1: torch.Tensor, @@ -428,31 +424,25 @@ def supports_chunking(self) -> bool: def supports_expert_map(self) -> 
bool: return False - # TODO(bnell): maybe remove need for passing aq to workspace_shapes + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - padded_M = aq.size(1) + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers assert num_dp is not None - workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K)) - output = (self.max_experts_per_worker, padded_M, K) - return ( - workspace1, - workspace2, - output, - self.out_dtype if self.out_dtype is not None else a.dtype, - ) + workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) + workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K)) + output = (self.max_experts_per_worker, M, K) + return (workspace1, workspace2, output) def cutlass_moe_fp8( @@ -467,7 +457,7 @@ def cutlass_moe_fp8( c_strides2: torch.Tensor, quant_config: FusedMoEQuantConfig, activation: str = "silu", - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, ) -> torch.Tensor: @@ -521,13 +511,19 @@ def cutlass_moe_fp8( assert quant_config is not None if quant_config.a1_scale is not None: - assert quant_config.per_act_token_quant == quant_config.a1_scale.numel() != 1 + assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1) if quant_config.a2_scale is not None: - assert quant_config.per_act_token_quant == quant_config.a2_scale.numel() != 1 + assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1) - assert quant_config.w1_scale is None or ( - quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1) == w1_q.size(1)) - ) + if quant_config.w1_scale is not None: + if quant_config.per_out_ch_quant: + assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size( + 1 + ) == w1_q.size(1) + else: + assert ( + quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1 + ) num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) @@ -767,36 +763,31 @@ def supports_chunking(self) -> bool: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1: tuple[int, ...] = () workspace2: tuple[int, ...] = () output: tuple[int, ...] 
= () if self.use_batched_format: - padded_M = aq.size(1) - workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) - output = (self.max_experts_per_worker, padded_M, K) + workspace1 = (self.max_experts_per_worker, M, max(N, K)) + workspace2 = (self.max_experts_per_worker, M, (N // 2)) + output = (self.max_experts_per_worker, M, K) else: workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) output = (M, K) - return ( - workspace1, - workspace2, - output, - self.out_dtype if self.out_dtype is not None else a.dtype, - ) + return (workspace1, workspace2, output) def apply( self, @@ -808,12 +799,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], # unused - a2_scale: Optional[torch.Tensor], # unused - workspace13: Optional[torch.Tensor], - workspace2: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, # unused + a2_scale: torch.Tensor | None, # unused + workspace13: torch.Tensor | None, + workspace2: torch.Tensor | None, + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): e, m, n, k, _ = self.moe_problem_size(hidden_states, w1, w2, topk_ids) @@ -854,7 +845,7 @@ def cutlass_moe_fp4( n: int, k: int, e: int, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: assert expert_map is None, ( @@ -911,7 +902,7 @@ def _valid_cutlass_block_scaled_grouped_gemm( inplace: bool, activation: str, apply_router_weight_on_input: bool, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, ) -> bool: def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): return N % 128 == 0 and K % 128 == 0 diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index fec3a7c5d0a9..484b8aa9d107 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch from tqdm import tqdm @@ -14,7 +13,6 @@ ) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( compute_aligned_M, - deep_gemm_block_shape, deepgemm_moe_permute, deepgemm_unpermute_and_reduce, ) @@ -28,14 +26,18 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) -from vllm.utils import has_deep_gemm, run_once -from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous +from vllm.utils.deep_gemm import ( + get_mk_alignment_for_contiguous_layout, + m_grouped_fp8_gemm_nt_contiguous, +) +from vllm.utils.func_utils import run_once +from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool: - align = deep_gemm_block_shape()[0] + align = get_mk_alignment_for_contiguous_layout()[0] return align <= M and N % align == 0 and K % align == 0 @@ -54,7 +56,7 @@ def _valid_deep_gemm( M = hidden_states.size(0) _, K, N = w2.size() - align = deep_gemm_block_shape()[0] + align = get_mk_alignment_for_contiguous_layout()[0] if not _valid_deep_gemm_shape(M, N, K): logger.debug_once( @@ -124,7 +126,7 @@ def 
warmup_deepgemm_gg_contiguous_kernels( assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" - block_m = deep_gemm_block_shape()[0] + block_m = get_mk_alignment_for_contiguous_layout()[0] num_experts = w1.size(0) device = w1.device @@ -173,7 +175,7 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor): class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config: FusedMoEQuantConfig): super().__init__(quant_config) - assert quant_config.block_shape == deep_gemm_block_shape() + assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout() assert quant_config.quant_dtype == torch.float8_e4m3fn assert not quant_config.per_act_token_quant assert not quant_config.per_out_ch_quant @@ -198,16 +200,14 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: assert self.block_shape is not None block_m = self.block_shape[0] M_sum = compute_aligned_M( @@ -218,7 +218,7 @@ def workspace_shapes( workspace1 = (M_sum, max(N, K)) workspace2 = (M_sum, max(N // 2, K)) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, @@ -230,12 +230,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): assert a1q_scale is not None @@ -257,7 +257,7 @@ def apply( M=topk_ids.size(0), num_topk=topk_ids.size(1), local_num_experts=local_num_experts, - alignment=deep_gemm_block_shape()[0], + alignment=get_mk_alignment_for_contiguous_layout()[0], expert_tokens_meta=expert_tokens_meta, ) @@ -286,7 +286,7 @@ def apply( self.activation(activation, act_out, mm1_out.view(-1, N)) - a2q_scale: Optional[torch.Tensor] = None + a2q_scale: torch.Tensor | None = None a2q, a2q_scale = per_token_group_quant_fp8( act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out ) @@ -319,9 +319,9 @@ def deep_gemm_moe_fp8( inplace: bool = False, activation: str = "silu", global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, apply_router_weight_on_input=False, ) -> torch.Tensor: """ @@ -366,7 +366,7 @@ def deep_gemm_moe_fp8( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=deep_gemm_block_shape(), + block_shape=get_mk_alignment_for_contiguous_layout(), ) fn = mk.FusedMoEModularKernel( diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py index 2ac968a9b4ab..85294f6aea6e 100644 --- 
a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py @@ -5,24 +5,13 @@ and updated to fit vllm needs and terminology. """ -import functools -from typing import Optional - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens from vllm.triton_utils import tl, triton from vllm.utils import round_up - - -@functools.cache -def deep_gemm_block_shape() -> list[int]: - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - block = dg.get_m_alignment_for_contiguous_layout() - return [block, block] +from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout def expert_num_tokens_round_up_and_sum( @@ -39,7 +28,7 @@ def compute_aligned_M( num_topk: int, local_num_experts: int, alignment: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, ): if (expert_tokens_meta is not None) and ( expert_tokens_meta.expert_num_tokens_cpu is not None @@ -175,7 +164,7 @@ def ep_scatter( recv_x_scale: torch.Tensor, recv_topk: torch.Tensor, num_recv_tokens_per_expert: torch.Tensor, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, expert_start_loc: torch.Tensor, output_tensor: torch.Tensor, output_tensor_scale: torch.Tensor, @@ -305,7 +294,7 @@ def ep_gather( recv_topk_ids: torch.Tensor, recv_topk_weight: torch.Tensor, input_index: torch.Tensor, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, output_tensor: torch.Tensor, ): num_warps = 2 @@ -346,17 +335,16 @@ def deepgemm_moe_permute( aq_scale: torch.Tensor, topk_ids: torch.Tensor, local_num_experts: int, - expert_map: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - aq_out: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + aq_out: torch.Tensor | None = None, ): assert aq.ndim == 2 assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids" H = aq.size(1) device = aq.device - block_m = deep_gemm_block_shape()[0] - block_k = deep_gemm_block_shape()[1] + block_m, block_k = get_mk_alignment_for_contiguous_layout() M_sum = compute_aligned_M( M=topk_ids.size(0), @@ -415,7 +403,7 @@ def deepgemm_unpermute_and_reduce( topk_ids: torch.Tensor, topk_weights: torch.Tensor, inv_perm: torch.Tensor, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, output: torch.Tensor, ): return ep_gather( diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 9a2844b7d998..a5c5c115f36c 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional, Union +from collections.abc import Callable import deep_ep import torch @@ -70,22 +70,25 @@ def __init__( def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - def max_num_tokens_per_rank(self) -> Optional[int]: + def max_num_tokens_per_rank(self) -> int | None: 
return None - def topk_indices_dtype(self) -> Optional[torch.dtype]: + def topk_indices_dtype(self) -> torch.dtype | None: return torch.int64 - def _get_dispatch_config(self) -> Optional[deep_ep.Config]: + def _get_dispatch_config(self) -> deep_ep.Config | None: if self.num_dispatchers_ not in self.available_rank_configs: return None return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_) - def _get_combine_config(self) -> Optional[deep_ep.Config]: + def _get_combine_config(self) -> deep_ep.Config | None: if self.num_dispatchers_ not in self.available_rank_configs: return None return deep_ep.Buffer.get_combine_config(self.num_dispatchers_) @@ -93,11 +96,11 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]: def _do_dispatch( self, tokens: torch.Tensor, - token_scales: Optional[torch.Tensor], + token_scales: torch.Tensor | None, rank_topk_ids: torch.Tensor, rank_topk_weights: torch.Tensor, num_experts: int, - a1_scale: Optional[torch.Tensor], + a1_scale: torch.Tensor | None, quant_config: FusedMoEQuantConfig, ) -> Callable: has_scales = token_scales is not None @@ -172,12 +175,12 @@ def _receiver( self, event: deep_ep.EventOverlap, has_scales: bool, - token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], - expert_topk_ids: Optional[torch.Tensor], + token_data: tuple[torch.Tensor, torch.Tensor] | torch.Tensor, + expert_topk_ids: torch.Tensor | None, num_experts: int, expert_num_tokens_per_expert_list: list[int], - expert_topk_weights: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], + expert_topk_weights: torch.Tensor | None, + a1_scale: torch.Tensor | None, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: if event.event is not None: @@ -246,7 +249,7 @@ def prepare_async( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.ReceiverType: @@ -291,7 +294,7 @@ def prepare( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: @@ -315,7 +318,7 @@ def _finalize( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, do_async: bool, - ) -> Optional[Callable]: + ) -> Callable | None: a2a_idx = dbo_current_ubatch_id() handle = self.handles[a2a_idx] assert handle is not None @@ -333,7 +336,11 @@ def _finalize( apply_router_weight_on_input=apply_router_weight_on_input, ) dbo_yield_and_switch_from_compute_to_comm() + assert fused_expert_output.dtype == torch.bfloat16, ( + f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}" + ) combined_x, _, event = self.buffer.combine( + # HT combine only supports BF16 x=fused_expert_output, handle=handle, topk_weights=None, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 6712995b52af..500bcefcfaa9 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional, Union +from collections.abc import Callable import deep_ep import torch @@ -50,7 +50,31 @@ class 
DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # DeepEP low-latency kernels are compiled only for certain # specific hidden sizes. - SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168] + # NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends + # on it. + SUPPORTED_HIDDEN_SIZES = [2048, 2560, 3072, 4096, 5120, 6144, 7168, 8192] + + @staticmethod + def maybe_roundup_layer_hidden_size(hidden_size: int) -> int: + # Round up hidden size to the closest supported hidden size. + _supported_hs = DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES + # Check sorted + num_supported_hs = len(_supported_hs) + assert all( + [ + _supported_hs[i] < _supported_hs[i + 1] + for i in range(num_supported_hs - 1) + ] + ) + + for x in _supported_hs: + if x >= hidden_size: + return x + + raise ValueError( + f"Hidden Size {hidden_size} is greater than the " + f"maximum supported hidden size {_supported_hs[-1]}" + ) def __init__( self, @@ -67,28 +91,31 @@ def __init__( # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. - self.handles: list[Optional[tuple]] = [None, None] + self.handles: list[tuple | None] = [None, None] self.num_dispatchers_ = num_dispatchers def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.BatchedExperts - def max_num_tokens_per_rank(self) -> Optional[int]: + def max_num_tokens_per_rank(self) -> int | None: return self.max_tokens_per_rank - def topk_indices_dtype(self) -> Optional[torch.dtype]: + def topk_indices_dtype(self) -> torch.dtype | None: return torch.int64 def _do_quant( self, - x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], a1_dtype: torch.dtype, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: if self.use_fp8_dispatch: block_k = ( quant_config.block_shape[1] @@ -134,7 +161,7 @@ def prepare_async( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> tuple[Callable, mk.ReceiverType]: @@ -197,9 +224,9 @@ def prepare_async( def _receiver( self, - expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], expert_num_tokens: torch.Tensor, - a1_scale: Optional[torch.Tensor], + a1_scale: torch.Tensor | None, a1_dtype: torch.dtype, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: @@ -217,7 +244,7 @@ def prepare( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index a2d8fe0da154..b7820319682b 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing 
import Optional import torch @@ -90,16 +89,14 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. """ @@ -118,14 +115,12 @@ def workspace_shapes( - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. """ - aq_m, aq_n = aq.shape + workspace1 = (M, K) workspace2 = (0,) - output_shape = (aq_m, aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m, aq_n) - workspace_dtype = a.dtype - workspace1 = output_shape + output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K) # The workspace is determined by `aq`, since it comes after any # potential communication op and is involved in the expert computation. - return (workspace1, workspace2, output_shape, workspace_dtype) + return (workspace1, workspace2, output_shape) def apply( self, @@ -137,13 +132,13 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: Optional[torch.Tensor], - workspace2: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: Optional[bool], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor | None, + workspace2: torch.Tensor | None, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool | None, ): assert activation == "silu", ( "Only activation silu is supported in FlashInferExperts" @@ -211,7 +206,7 @@ def flashinfer_cutlass_moe_fp4( inplace: bool = False, activation: str = "silu", global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: fused_experts = mk.FusedMoEModularKernel( @@ -246,7 +241,7 @@ def flashinfer_cutlass_moe( inplace: bool = False, activation: str = "silu", global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, tp_rank: int = 0, tp_size: int = 1, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 04bc987d0885..20e2f6c85186 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -11,6 +10,9 @@ ) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from 
vllm.utils.flashinfer import nvfp4_block_scale_interleave @@ -36,15 +38,18 @@ def __init__( def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - def max_num_tokens_per_rank(self) -> Optional[int]: + def max_num_tokens_per_rank(self) -> int | None: return None - def topk_indices_dtype(self) -> Optional[torch.dtype]: + def topk_indices_dtype(self) -> torch.dtype | None: return None def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return False + def _apply_router_weight_on_input( self, a1: torch.Tensor, @@ -83,7 +88,7 @@ def prepare( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: @@ -158,7 +163,7 @@ def prepare( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: @@ -194,6 +199,8 @@ def finalize( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: + assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP) + if self.use_dp: fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output, dim=0, sizes=get_local_sizes() diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index d12d05915566..f21fe16c5108 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -11,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def flashinfer_fused_moe_blockscale_fp8( @@ -105,7 +104,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake( def flashinfer_fused_moe_per_tensor_scale_fp8( routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], + routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, input_scale: torch.Tensor, gemm1_weights: torch.Tensor, @@ -115,8 +114,8 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( output2_scales_scalar: torch.Tensor, num_experts: int, top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], + num_expert_group: int | None, + topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, @@ -163,7 +162,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( def flashinfer_fused_moe_per_tensor_scale_fp8_fake( routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], + routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, input_scale: torch.Tensor, gemm1_weights: torch.Tensor, @@ -173,8 +172,8 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake( output2_scales_scalar: torch.Tensor, num_experts: int, top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], + num_expert_group: int | None, + topk_group: int | None, intermediate_size: int, local_expert_offset: int, local_num_experts: int, diff --git 
a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 02a935a1dca2..7fd8511e297d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" -from typing import Optional - import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -370,8 +368,8 @@ def invoke_moe_batched_triton_kernel( expert_num_tokens: torch.Tensor, # [E] compute_type: tl.dtype, # Quantization data - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, + B_scale: torch.Tensor | None, B_zp: torch.Tensor, # Quantization schemes use_fp8_w8a8: bool, @@ -379,7 +377,7 @@ def invoke_moe_batched_triton_kernel( use_int4_w4a16: bool, config: dict[str, int], per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ): assert not use_int4_w4a16 max_num_tokens = A.size(1) @@ -500,22 +498,25 @@ def __init__( def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.BatchedExperts - def max_num_tokens_per_rank(self) -> Optional[int]: + def max_num_tokens_per_rank(self) -> int | None: return self.max_num_tokens - def topk_indices_dtype(self) -> Optional[torch.dtype]: + def topk_indices_dtype(self) -> torch.dtype | None: return None def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return False + def prepare( self, a1: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: @@ -665,23 +666,20 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers num_experts = local_num_experts workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace2 = (self.max_num_tokens * num_dp, N) output = workspace13 - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: assert self.quant_config.is_quantized @@ -701,12 +699,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): assert hidden_states.dim() == 3 @@ -754,15 +752,15 @@ def apply( def batched_moe_kernel_quantize_input( A: torch.Tensor, - A_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, num_tokens: int, 
E: int, N: int, expert_num_tokens: torch.Tensor, - qtype: Optional[torch.dtype], + qtype: torch.dtype | None, per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + block_shape: list[int] | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None]: if torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing(): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. @@ -862,24 +860,21 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers num_experts = local_num_experts max_num_tokens = self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) output = (num_experts, max_num_tokens * num_dp, K) - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def apply( self, @@ -891,12 +886,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): # Check constraints. 
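A minimal sketch (illustrative only, not part of the diff) of the batched workspace layout that the workspace_shapes() changes above now return: with the activation dtype dropped from the return value, the method yields just three shape tuples, and workspace13 doubles as the output buffer. The helper name and the example sizes below are hypothetical.

def batched_workspace_shapes(
    num_local_experts: int, max_num_tokens: int, num_dispatchers: int, N: int, K: int
):
    # Mirrors the first workspace_shapes() above: one slot per local expert,
    # with max_num_tokens * num_dispatchers rows of width K (or N for the
    # intermediate buffer).
    workspace13 = (num_local_experts, max_num_tokens * num_dispatchers, K)
    workspace2 = (max_num_tokens * num_dispatchers, N)
    output = workspace13
    return (workspace13, workspace2, output)

# e.g. 8 local experts, 64 tokens per rank, 2 dispatchers, N=1024, K=2048
# -> ((8, 128, 2048), (128, 1024), (8, 128, 2048))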
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index c46cc016214f..3b0df6c416a0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -2,155 +2,109 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE utilities for GPTQ.""" -from typing import Optional +from collections.abc import Callable import torch -from typing_extensions import override import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + batched_moe_align_block_size, + moe_align_block_size, +) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP, ) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_make_workspace_new, marlin_moe_intermediate_size, maybe_warn_marlin_atomic_add, ) from vllm.scalar_type import ScalarType, scalar_types -from vllm.utils import direct_register_custom_op -def fused_marlin_moe( +def default_activation_func( + activation: str, output: torch.Tensor, input: torch.Tensor +) -> None: + if activation == "silu": + torch.ops._C.silu_and_mul(output, input) + elif activation == "swigluoai": + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul(output, input) + else: + raise ValueError( + f"Unsupported activation: {activation}. " + "Only silu and swigluoai activations are supported." 
+ ) + + +def _fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - bias1: Optional[torch.Tensor], - bias2: Optional[torch.Tensor], + bias1: torch.Tensor | None, + bias2: torch.Tensor | None, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - gating_output: Optional[torch.Tensor], topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_type_id: int, - apply_router_weight_on_input: bool = False, - global_num_experts: int = -1, - activation: Optional[str] = "silu", - expert_map: Optional[torch.Tensor] = None, - global_scale1: Optional[torch.Tensor] = None, - global_scale2: Optional[torch.Tensor] = None, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - intermediate_cache13: Optional[torch.Tensor] = None, - intermediate_cache2: Optional[torch.Tensor] = None, + num_topk: int, + quant_type: ScalarType, + apply_router_weight_on_input: bool, + expert_map: torch.Tensor | None, + block_size_m: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + activation: str = "silu", + activation_func: Callable[ + [str, torch.Tensor, torch.Tensor], None + ] = default_activation_func, + global_scale1: torch.Tensor | None = None, + global_scale2: torch.Tensor | None = None, + g_idx1: torch.Tensor | None = None, + g_idx2: torch.Tensor | None = None, + sort_indices1: torch.Tensor | None = None, + sort_indices2: torch.Tensor | None = None, + w1_zeros: torch.Tensor | None = None, + w2_zeros: torch.Tensor | None = None, + workspace: torch.Tensor | None = None, + intermediate_cache13: torch.Tensor | None = None, + intermediate_cache2: torch.Tensor | None = None, + output: torch.Tensor | None = None, is_k_full: bool = True, - output: Optional[torch.Tensor] = None, - inplace: bool = False, ) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - w1_scale (torch.Tensor): Scale to be used for w1. - - w2_scale (torch.Tensor): Scale to be used for w2. - - gating_output (Optional[torch.Tensor]): The output of the gating - operation (before softmax). - - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. - - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. - - sort_indices1 (Optional[torch.Tensor]): The first act_order input - permutation. - - sort_indices2 (Optional[torch.Tensor]): The second act_order input - permutation. - - topk_weights (torch.Tensor): Top-k weights. - - topk_ids (torch.Tensor): Indices of topk-k elements. - - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. - - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. - - num_bits (bool): The number of bits in expert weights quantization. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. 
- """ - quant_type = ScalarType.from_id(quant_type_id) - assert quant_type in [ - scalar_types.uint4, - scalar_types.uint8b128, - scalar_types.uint4b8, - scalar_types.float8_e4m3fn, - scalar_types.float4_e2m1f, - ] - - bit4_scalar_types = [ - scalar_types.uint4, - scalar_types.uint4b8, - scalar_types.float4_e2m1f, - ] - num_bits = 4 if quant_type in bit4_scalar_types else 8 - - # Check constraints. - if gating_output is not None: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch" - ) - assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[1] == w2.shape[2] // (num_bits // 2), ( - "Hidden size mismatch w2" - ) - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float16, torch.bfloat16] - assert num_bits in [4, 8] - assert topk_weights.dtype == torch.float32 - - M, K = hidden_states.shape - E = w1.shape[0] + assert hidden_states.ndim == 2 + M, K = hidden_states.size() N = marlin_moe_intermediate_size(w1, w2) - topk = topk_ids.shape[1] - - # M block size selection logic - # TODO: tune this further for specific models - for block_size_m in [8, 16, 32, 48, 64]: - if M * topk / E / block_size_m < 0.9: - break - - if global_num_experts == -1: - global_num_experts = E - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, block_size_m, global_num_experts, expert_map - ) if workspace is None: workspace = marlin_make_workspace_new(hidden_states.device, 4) - if intermediate_cache2 is None: - intermediate_cache2 = torch.empty( - (M * topk, N), + if intermediate_cache13 is None: + intermediate_cache13 = torch.empty( + (M * num_topk * max(2 * N, K),), device=hidden_states.device, dtype=hidden_states.dtype, ) - if intermediate_cache13 is None: - intermediate_cache13 = torch.empty( - (M * topk * max(2 * N, K),), + if intermediate_cache2 is None: + intermediate_cache2 = torch.empty( + (M * num_topk, N), device=hidden_states.device, dtype=hidden_states.dtype, ) - intermediate_cache1 = _resize_cache(intermediate_cache13, (M * topk, 2 * N)) - intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K)) - intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N)) + intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N)) + + intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K)) + + intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N)) maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) use_atomic_add = ( @@ -174,7 +128,7 @@ def fused_marlin_moe( num_tokens_post_padded, topk_weights, moe_block_size=block_size_m, - top_k=topk, + top_k=num_topk, mul_topk_weights=apply_router_weight_on_input, is_ep=expert_map is not None, b_q_type=quant_type, @@ -187,27 +141,19 @@ def fused_marlin_moe( is_zp_float=False, ) - if activation == "silu": - torch.ops._C.silu_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, 2 * N) - ) - elif activation == "swigluoai": - # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, 2 * N) - ) - else: - raise ValueError( - f"Unsupported activation: {activation}. " - "Only silu and swigluoai activations are supported." 
- ) + activation_func( + activation, intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + ) + + if output is None: + output = intermediate_cache3 if expert_map is not None: - intermediate_cache3.zero_() + output.zero_() - intermediate_cache3 = ops.moe_wna16_marlin_gemm( + output = ops.moe_wna16_marlin_gemm( intermediate_cache2, - intermediate_cache3, + output, w2, bias2, w2_scale, @@ -225,65 +171,339 @@ def fused_marlin_moe( mul_topk_weights=not apply_router_weight_on_input, is_ep=expert_map is not None, b_q_type=quant_type, - size_m=M * topk, + size_m=M * num_topk, size_n=K, size_k=N, is_k_full=is_k_full, use_atomic_add=use_atomic_add, use_fp32_reduce=True, is_zp_float=False, - ).view(-1, topk, K) + ) - if output is None: - output = hidden_states if inplace else torch.empty_like(hidden_states) - return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output) + return output -def fused_marlin_moe_fake( +def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + bias1: torch.Tensor | None, + bias2: torch.Tensor | None, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - gating_output: Optional[torch.Tensor], + gating_output: torch.Tensor | None, topk_weights: torch.Tensor, topk_ids: torch.Tensor, quant_type_id: int, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - global_scale1: Optional[torch.Tensor] = None, - global_scale2: Optional[torch.Tensor] = None, - expert_map: Optional[torch.Tensor] = None, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - intermediate_cache13: Optional[torch.Tensor] = None, - intermediate_cache2: Optional[torch.Tensor] = None, + activation: str = "silu", + activation_func: Callable[ + [str, torch.Tensor, torch.Tensor], None + ] = default_activation_func, + moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None, + expert_map: torch.Tensor | None = None, + global_scale1: torch.Tensor | None = None, + global_scale2: torch.Tensor | None = None, + g_idx1: torch.Tensor | None = None, + g_idx2: torch.Tensor | None = None, + sort_indices1: torch.Tensor | None = None, + sort_indices2: torch.Tensor | None = None, + w1_zeros: torch.Tensor | None = None, + w2_zeros: torch.Tensor | None = None, + workspace: torch.Tensor | None = None, + intermediate_cache13: torch.Tensor | None = None, + intermediate_cache2: torch.Tensor | None = None, is_k_full: bool = True, - output: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, inplace: bool = False, ) -> torch.Tensor: - return torch.empty_like(hidden_states) + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - w1_scale (torch.Tensor): Scale to be used for w1. + - w2_scale (torch.Tensor): Scale to be used for w2. + - gating_output (torch.Tensor|None): The output of the gating + operation (before softmax). + - g_idx1 (torch.Tensor|None): The first set of act_order indices. + - g_idx2 (torch.Tensor|None): The second set of act_order indices. 
+ - sort_indices1 (torch.Tensor|None): The first act_order input + permutation. + - sort_indices2 (torch.Tensor|None): The second act_order input + permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of top-k elements. + - w1_zeros (torch.Tensor|None): Optional zero points to be used for w1. + - w2_zeros (torch.Tensor|None): Optional zero points to be used for w2. + - quant_type_id (int): Scalar type id of the quantized expert weights (used to derive the number of bits, 4 or 8). + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ -direct_register_custom_op( - op_name="fused_marlin_moe", - op_func=fused_marlin_moe, - fake_impl=fused_marlin_moe_fake, -) + if inplace: + assert output is None, "Conflicting request" + + quant_type = ScalarType.from_id(quant_type_id) + assert quant_type in [ + scalar_types.uint4, + scalar_types.uint8b128, + scalar_types.uint4b8, + scalar_types.float8_e4m3fn, + scalar_types.float4_e2m1f, + ] + + bit4_scalar_types = [ + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.float4_e2m1f, + ] + num_bits = 4 if quant_type in bit4_scalar_types else 8 + + M, K = hidden_states.size() + E = w1.size(0) + topk = topk_ids.size(1) + + # Check constraints. + if gating_output is not None: + assert gating_output.size(0) == M, "Number of tokens mismatch" + assert w1.size(1) * 16 == K, "Hidden size mismatch w1" + assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float16, torch.bfloat16] + assert num_bits in [4, 8] + assert topk_weights.dtype == torch.float32 + # M block size selection logic + # TODO: tune this further for specific models + for block_size_m in [8, 16, 32, 48, 64]: + if M * topk / E / block_size_m < 0.9: + break + + if global_num_experts == -1: + global_num_experts = E + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, block_size_m, global_num_experts, expert_map + ) + + assert activation is not None + moe_output = _fused_marlin_moe( + hidden_states=hidden_states, + w1=w1, + w2=w2, + bias1=bias1, + bias2=bias2, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + num_topk=topk, + quant_type=quant_type, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + block_size_m=block_size_m, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + activation=activation, + activation_func=activation_func, + global_scale1=global_scale1, + global_scale2=global_scale2, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=w1_zeros, + w2_zeros=w2_zeros, + workspace=workspace, + intermediate_cache13=intermediate_cache13, + intermediate_cache2=intermediate_cache2, + output=None, + is_k_full=is_k_full, + ).view(-1, topk, K) + + if output is None: + if inplace and not disable_inplace(): + output = hidden_states + else: + output = torch.empty_like(hidden_states) + + if moe_sum is None: + return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output) + else: + return moe_sum(moe_output, output) + + +def batched_fused_marlin_moe( + hidden_states: torch.Tensor, + expert_num_tokens: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + bias1: torch.Tensor | None, + bias2: torch.Tensor |
None, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor | None, + quant_type_id: int, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + activation: str | None = "silu", + expert_map: torch.Tensor | None = None, + global_scale1: torch.Tensor | None = None, + global_scale2: torch.Tensor | None = None, + g_idx1: torch.Tensor | None = None, + g_idx2: torch.Tensor | None = None, + sort_indices1: torch.Tensor | None = None, + sort_indices2: torch.Tensor | None = None, + w1_zeros: torch.Tensor | None = None, + w2_zeros: torch.Tensor | None = None, + workspace: torch.Tensor | None = None, + intermediate_cache13: torch.Tensor | None = None, + intermediate_cache2: torch.Tensor | None = None, + is_k_full: bool = True, + output: torch.Tensor | None = None, + inplace: bool = False, +) -> torch.Tensor: + """ + This function massages the inputs so the batched hidden_states can be + presented as a 2D contiguous tensor that could be used with + _fused_marlin_moe. + + Note that both batched_fused_marlin_moe and fused_marlin_moe ultimately + use `ops.moe_wna16_marlin_gemm` for the gemm operation, and + `ops.moe_wna16_marlin_gemm` supports only 2D contiguous hidden_states. + Note that the moe_align_block_size function indicates: + - What rows of the A matrix (hidden_states) to access during the + matmul, via the sorted_ids output. + - What expert_id to use for each block matmul, via the expert_ids output. + + In the batched version, the tokens are already grouped/batched by the experts + they subscribe to. Due to this, we can represent the batched hidden_states + tensor of shape [B, MAX_TOKENS_PER_BATCH, K] as a 2D tensor of shape + [B * MAX_TOKENS_PER_BATCH, K]. We may treat this as a 2D contiguous tensor + with topk=1, since each token (row in the tensor) subscribes to exactly one + expert_id (which is the batch_id). With the expert_num_tokens tensor, which + indicates how many tokens are actually valid in each batch, the + batched_moe_align_block_size function constructs the sorted_ids and + expert_ids tensors, so only relevant/valid rows of A (hidden_states) + are accessed and are processed with the correct expert_ids. + """ + + assert hidden_states.ndim == 3, ( + "hidden states must be batched, e.g. [B, MAX_TOKENS, K]. " + f"But got {hidden_states.size()}" + ) + if inplace: + assert output is None, "Conflicting request." + + quant_type = ScalarType.from_id(quant_type_id) + assert quant_type in [ + scalar_types.uint4, + scalar_types.uint8b128, + scalar_types.uint4b8, + scalar_types.float8_e4m3fn, + scalar_types.float4_e2m1f, + ] + + bit4_scalar_types = [ + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.float4_e2m1f, + ] + num_bits = 4 if quant_type in bit4_scalar_types else 8 + + B, BATCH_TOKENS_MAX, K = hidden_states.size() + M = hidden_states.view(-1, K).size(0) + E = w1.size(0) + + # Check constraints.
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert hidden_states.dtype in [torch.float16, torch.bfloat16] + assert expert_num_tokens.size(0) == E + assert B == E, ( + "Batch must be as big as the number of experts, as the tokens " + "are sorted into the batch/expert they belong to" + ) + assert w1.size(1) * 16 == K, "Hidden size mismatch w1" + assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert num_bits in [4, 8] + + # The tokens are already separated by their expert ids, so hidden_states + # can be squeezed to 2 dimensions, [B * MAX_TOKENS, K], and top_k can be + # interpreted as 1. + topk = 1 + + # TODO(varun) : Choose a decent block size like in fused_marlin_moe + block_size_m = 64 + + sorted_token_ids, expert_ids, num_tokens_post_padded = batched_moe_align_block_size( + max_tokens_per_batch=BATCH_TOKENS_MAX, + block_size=block_size_m, + expert_num_tokens=expert_num_tokens, + ) + + if output is None and inplace: + output = hidden_states + + # TODO (varun): This can be avoided by plumbing the marlin kernel to + # ignore topk_weights when topk_weights_ptr is a nullptr. + topk_weights = torch.ones( + (M, topk), device=hidden_states.device, dtype=torch.float32 + ) -class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): + assert activation is not None + output = _fused_marlin_moe( + hidden_states=hidden_states.view(-1, K), + w1=w1, + w2=w2, + bias1=bias1, + bias2=bias2, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + num_topk=topk, + quant_type=quant_type, + apply_router_weight_on_input=apply_router_weight_on_input, + activation=activation, + expert_map=expert_map, + block_size_m=block_size_m, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + global_scale1=global_scale1, + global_scale2=global_scale2, + g_idx1=g_idx1, + g_idx2=g_idx2, + sort_indices1=sort_indices1, + sort_indices2=sort_indices2, + w1_zeros=w1_zeros, + w2_zeros=w2_zeros, + workspace=workspace, + intermediate_cache13=intermediate_cache13, + intermediate_cache2=intermediate_cache2, + output=output.view(-1, K) if output is not None else output, + is_k_full=is_k_full, + ) + + output = output.view(B, BATCH_TOKENS_MAX, K) + + return output + + +class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config: FusedMoEQuantConfig): # TODO (varun) : Enable activation quantization assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" super().__init__(quant_config) - @override def moe_problem_size( self, a1: torch.Tensor, @@ -311,6 +531,11 @@ def moe_problem_size( return E, M, N, K, topk + +class MarlinExperts(MarlinExpertsBase): + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__(quant_config) + def supports_expert_map(self) -> bool: return True @@ -331,16 +556,14 @@ def supports_chunking(self) -> bool: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Modular Kernel provisions output buffer from workspace1.
However in # the fused_marlin_moe() function, the final torch.sum(), is defined # essentially as, @@ -360,7 +583,7 @@ def workspace_shapes( workspace2 = (M * topk * max(2 * N, K),) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, @@ -372,12 +595,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): assert self.w1_scale is not None @@ -397,6 +620,8 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, + activation_func=self.activation, + moe_sum=self.moe_sum, expert_map=expert_map, output=output, # Workspaces are swapped in workspace_shapes() to account for proper @@ -404,3 +629,103 @@ def apply( intermediate_cache13=workspace2, intermediate_cache2=workspace13, ) + + def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: + ops.moe_sum(input, output) + + +def modular_marlin_fused_moe( + quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None +) -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + MarlinExperts(quant_config), + shared_experts, + ) + + +class BatchedMarlinExperts(MarlinExpertsBase): + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(quant_config) + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + + def supports_expert_map(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceDelegate() + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) + + def supports_chunking(self) -> bool: + return False + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + num_dispatchers = self.num_dispatchers + num_experts = local_num_experts + max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens + workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2)) + workspace2 = (num_experts * max_num_tokens * num_dispatchers, N) + output = (num_experts, max_num_tokens * num_dispatchers, K) + return (workspace13, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert expert_tokens_meta is not None, "Num 
valid tokens per batch is required" + return batched_fused_marlin_moe( + hidden_states=hidden_states, + expert_num_tokens=expert_tokens_meta.expert_num_tokens, + w1=w1, + w2=w2, + bias1=self.w1_bias, + bias2=self.w2_bias, + w1_scale=self.w1_scale, + w2_scale=self.w2_scale, + gating_output=None, + quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16 + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + activation=activation, + expert_map=expert_map, + output=output, + intermediate_cache13=workspace13, + intermediate_cache2=workspace2, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2a3abcaadebd..89e92edc8d2b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -5,7 +5,8 @@ import functools import json import os -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any import torch import torch.nn.functional as F @@ -14,6 +15,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, @@ -39,15 +43,17 @@ from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, activation_without_mul, + disable_inplace, moe_kernel_quantize_input, ) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -538,10 +544,10 @@ def invoke_fused_moe_kernel( A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor], + A_scale: torch.Tensor | None, + B_scale: torch.Tensor | None, + B_zp: torch.Tensor | None, + topk_weights: torch.Tensor | None, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, @@ -554,8 +560,8 @@ def invoke_fused_moe_kernel( use_int8_w8a16: bool, use_int4_w4a16: bool, per_channel_quant: bool, - block_shape: Optional[list[int]] = None, - B_bias: Optional[torch.Tensor] = None, + block_shape: list[int] | None = None, + B_bias: torch.Tensor | None = None, ) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 @@ -807,7 +813,7 @@ def zero_experts_compute_triton( # Adapted from: https://github.com/sgl-project/sglang/pull/2628 def get_config_file_name( - E: int, N: int, dtype: Optional[str], block_shape: Optional[list[int]] = None + E: int, N: int, dtype: str | None, block_shape: list[int] | None = None ) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector 
= "" if not dtype else f",dtype={dtype}" @@ -822,10 +828,10 @@ def get_config_file_name( def get_moe_configs( E: int, N: int, - dtype: Optional[str], - block_n: Optional[int] = None, - block_k: Optional[int] = None, -) -> Optional[dict[int, Any]]: + dtype: str | None, + block_n: int | None = None, + block_k: int | None = None, +) -> dict[int, Any] | None: """ Return optimized configurations for the fused MoE kernel. @@ -835,6 +841,10 @@ def get_moe_configs( be picked and the associated configuration chosen to invoke the kernel. """ + # Avoid optimizing for the batch invariant case. Use default config + if vllm_is_batch_invariant(): + return None + # First look up if an optimized configuration is available in the configs # directory block_shape = [block_n, block_k] if block_n and block_k else None @@ -964,9 +974,18 @@ def get_default_config( N: int, K: int, topk: int, - dtype: Optional[str], - block_shape: Optional[list[int]] = None, + dtype: str | None, + block_shape: list[int] | None = None, ) -> dict[str, int]: + if vllm_is_batch_invariant(): + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + return config + if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # BLOCK_SIZE_K must be divisible by block_shape[1] @@ -1015,9 +1034,9 @@ def try_get_optimal_moe_config( w1_shape: tuple[int, ...], w2_shape: tuple[int, ...], top_k: int, - dtype: Optional[str], + dtype: str | None, M: int, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config @@ -1055,9 +1074,8 @@ def vllm_topk_softmax( topk_indices, token_expert_indices, gating_output, + renormalize, ) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_indices @@ -1075,7 +1093,7 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, - indices_type: Optional[torch.dtype] = None, + indices_type: torch.dtype | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" @@ -1094,11 +1112,9 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. 
- topk_func = dispatch_topk_func() topk_weights, topk_ids = topk_func( - topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize + topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) return topk_weights, topk_ids, token_expert_indices @@ -1116,7 +1132,10 @@ def fused_topk_bias( scores_for_choice = scores.view( -1, n_routed_experts ) + e_score_correction_bias.unsqueeze(0) - topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1] + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_is_batch_invariant() + topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1] topk_weights = scores.gather(1, topk_indices) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -1124,7 +1143,11 @@ def fused_topk_bias( # This is used by the Deepseek-V2 and Deepseek-V3 model -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +@torch.compile( + dynamic=True, + backend=current_platform.simple_compile_backend, + options=maybe_disable_graph_partition(current_platform.simple_compile_backend), +) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -1134,7 +1157,7 @@ def grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if ( envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK @@ -1177,7 +1200,10 @@ def grouped_topk( group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values ) # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_is_batch_invariant() + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ 1 ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] @@ -1190,11 +1216,13 @@ def grouped_topk( tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=use_sorted + ) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -1210,7 +1238,7 @@ def eplb_map_to_physical_and_record( expert_load_view: torch.Tensor, logical_to_physical_map: torch.Tensor, logical_replica_count: torch.Tensor, - indices_type: Optional[torch.dtype] = None, + indices_type: torch.dtype | None = None, ) -> torch.Tensor: """ Map the logical expert ids to physical expert ids @@ -1325,19 +1353,19 @@ def inplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - ocp_mx_scheme: Optional[str] = None, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: 
Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> None: fused_experts_impl( hidden_states, @@ -1380,19 +1408,19 @@ def inplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - ocp_mx_scheme: Optional[str] = None, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> None: pass @@ -1422,19 +1450,19 @@ def outplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - ocp_mx_scheme: Optional[str] = None, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -1476,19 +1504,19 @@ def outplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - ocp_mx_scheme: Optional[str] = None, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: 
Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1516,7 +1544,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: - if inplace: + if inplace and not disable_inplace(): return torch_vllm_inplace_fused_experts return torch_vllm_outplace_fused_experts @@ -1533,8 +1561,8 @@ def fused_experts( activation: str = "silu", apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - quant_config: Optional[FusedMoEQuantConfig] = None, + expert_map: torch.Tensor | None = None, + quant_config: FusedMoEQuantConfig | None = None, allow_deep_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False, ) -> torch.Tensor: @@ -1619,13 +1647,14 @@ def fused_experts( SILU_NO_MUL: str = activation_without_mul("silu") GELU_NO_MUL: str = activation_without_mul("gelu") +RELU2_NO_MUL: str = activation_without_mul("relu2") def _get_config_quant_dtype( use_fp8_w8a8: bool, use_int8_w8a8: bool, - ocp_mx_scheme: Optional[str], -) -> Union[None, torch.dtype, str]: + ocp_mx_scheme: str | None, +) -> None | torch.dtype | str: """ Get the quantization type based on the quantization strategy flags. We don't have a quant_config at this point so we need to work backwards. @@ -1659,19 +1688,19 @@ def fused_experts_impl( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - ocp_mx_scheme: Optional[str] = None, + ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[list[int]] = None, - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: list[int] | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: # Check constraints. 
if use_int4_w4a16: @@ -1766,7 +1795,10 @@ def fused_experts_impl( else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") - out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states) + if inplace and not disable_inplace(): + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) if ocp_mx_scheme is not None: # TODO: On platforms for which `current_platform.supports_mx()` is True @@ -1883,7 +1915,8 @@ def fused_experts_impl( intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) - + elif activation == RELU2_NO_MUL: + intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N))) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}.") @@ -1954,20 +1987,18 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M, topk, max(N // 2, K)) workspace2 = (M, topk, max(N, K)) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, @@ -1979,12 +2010,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): # Check constraints. 
@@ -2072,7 +2103,7 @@ def apply( activation, intermediate_cache2, intermediate_cache1.view(-1, N) ) - a2q_scale: Optional[torch.Tensor] = None + a2q_scale: torch.Tensor | None = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( intermediate_cache2, @@ -2106,13 +2137,18 @@ def apply( B_bias=self.w2_bias, ) - ops.moe_sum(intermediate_cache3, output) + # separate function is required for MoE + LoRA + self.moe_sum(intermediate_cache3, output) + + def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: + ops.moe_sum(input, output) def modular_triton_fused_moe( - quant_config: FusedMoEQuantConfig, + quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), TritonExperts(quant_config), + shared_experts, ) diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 39faeed5d10f..badedfc54c38 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -14,7 +13,7 @@ TopKWeightAndReduceNoOP, ) from vllm.triton_utils import tl, triton -from vllm.utils import has_triton_kernels +from vllm.utils.import_utils import has_triton_kernels logger = init_logger(__name__) @@ -80,10 +79,10 @@ def triton_kernel_moe_forward( topk: int, renormalize: bool, activation: str = "silu", - quant_config: Optional[FusedMoEQuantConfig] = None, + quant_config: FusedMoEQuantConfig | None = None, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, ) -> torch.Tensor: routing_data, gather_idx, scatter_idx = routing( gating_output, topk, sm_first=not renormalize @@ -115,13 +114,13 @@ def triton_kernel_fused_experts( gather_indx, # GatherIndx scatter_indx, # ScatterIndx activation: str = "silu", - quant_config: Optional[FusedMoEQuantConfig] = None, + quant_config: FusedMoEQuantConfig | None = None, swiglu_alpha: float = 1.702, swiglu_limit: float = 7.0, apply_router_weight_on_input: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - a1q_scale: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, + a1q_scale: torch.Tensor | None = None, ) -> torch.Tensor: if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG @@ -255,21 +254,19 @@ def supports_chunking(self) -> bool: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # workspace are allocated inside the kernel workspace1 = (M, K) workspace2 = (0, 0) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, @@ -281,12 +278,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: 
Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): if expert_map is not None: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9c8ccc6ec008..71393f4f6c27 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2,17 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod -from collections.abc import Iterable +from collections.abc import Callable, Iterable from contextlib import nullcontext from enum import Enum -from typing import Callable, Literal, Optional, Union, get_args, overload +from functools import partial +from typing import Literal, get_args, overload import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import get_current_vllm_config +from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy from vllm.distributed import ( get_dp_group, @@ -39,6 +40,8 @@ FusedMoEPrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + init_aiter_topK_meta_data, + is_rocm_aiter_fusion_shared_expert_enabled, is_rocm_aiter_moe_enabled, ) from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator @@ -46,11 +49,16 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + is_flashinfer_supporting_global_sf, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up +from vllm.utils import cdiv, round_up from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.import_utils import has_deep_ep, has_pplx +from vllm.utils.torch_utils import current_stream, direct_register_custom_op from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): @@ -70,15 +78,15 @@ ) else: fused_experts = None # type: ignore - FusedMoEPermuteExpertsUnpermute = None # type: ignore - FusedMoEPrepareAndFinalize = None # type: ignore + FusedMoEPermuteExpertsUnpermute = object # type: ignore + FusedMoEPrepareAndFinalize = object # type: ignore def _eplb_map_to_physical_and_record( topk_ids: torch.Tensor, expert_load_view: torch.Tensor, logical_to_physical_map: torch.Tensor, logical_replica_count: torch.Tensor, - indices_type: Optional[torch.dtype], + indices_type: torch.dtype | None, ) -> torch.Tensor: # CPU fallback: no EPLB so just return as is return topk_ids @@ -87,7 +95,7 @@ def _eplb_map_to_physical_and_record( if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk, + rocm_aiter_grouped_topk as grouped_topk_aiter, ) else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk @@ -110,8 +118,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe = moe - self.moe_quant_config: 
Optional[FusedMoEQuantConfig] = None - self.fused_experts: Optional[FusedMoEModularKernel] = None + self.moe_quant_config: FusedMoEQuantConfig | None = None + self.fused_experts: FusedMoEModularKernel | None = None self.topk_indices_dtype = None @abstractmethod @@ -139,12 +147,12 @@ def uses_weight_scale_2_pattern(self) -> bool: @staticmethod def _maybe_make_prepare_finalize( moe: FusedMoEConfig, - quant_config: Optional[FusedMoEQuantConfig], - ) -> Optional[FusedMoEPrepareAndFinalize]: + quant_config: FusedMoEQuantConfig | None, + ) -> FusedMoEPrepareAndFinalize | None: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + prepare_finalize: FusedMoEPrepareAndFinalize | None = None # TODO: could allow this now assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py" @@ -229,7 +237,7 @@ def _maybe_make_prepare_finalize( return prepare_finalize - def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: if self.moe.moe_parallel_config.use_all2all_kernels: return FusedMoEMethodBase._maybe_make_prepare_finalize( self.moe, self.moe_quant_config @@ -280,9 +288,13 @@ def select_gemm_impl( @abstractmethod def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: raise NotImplementedError + @property + def using_modular_kernel(self) -> bool: + return self.fused_experts is not None + @abstractmethod def apply( self, @@ -292,21 +304,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError @@ -356,15 +368,17 @@ def __init__(self, moe: FusedMoEConfig): logger.info_once( "FlashInfer CUTLASS MoE is available for EP" " but not enabled, consider setting" - " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it." + " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.", + scope="local", ) elif self.moe.moe_parallel_config.dp_size > 1: logger.info_once( - "FlashInfer CUTLASS MoE is currently not available for DP." 
+ "FlashInfer CUTLASS MoE is currently not available for DP.", + scope="local", ) self.flashinfer_cutlass_moe = None # type: ignore - def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: if self.rocm_aiter_moe_enabled: return None else: @@ -399,11 +413,15 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + if self.moe.is_act_and_mul: + w13_up_dim = 2 * intermediate_size_per_partition + else: + w13_up_dim = intermediate_size_per_partition # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.empty( num_experts, - 2 * intermediate_size_per_partition, + w13_up_dim, hidden_size, dtype=params_dtype, ), @@ -413,9 +431,7 @@ def create_weights( set_weight_attrs(w13_weight, extra_weight_attrs) if self.moe.has_bias: w13_bias = torch.nn.Parameter( - torch.zeros( - num_experts, 2 * intermediate_size_per_partition, dtype=params_dtype - ), + torch.zeros(num_experts, w13_up_dim, dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w13_bias", w13_bias) @@ -528,21 +544,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None @@ -574,7 +590,7 @@ def apply( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: if self.moe.has_bias: return biased_moe_quant_config( layer.w13_bias, @@ -591,21 +607,21 @@ def forward_cuda( top_k: int, router_logits: torch.Tensor, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) 
-> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_type = getattr(layer, "zero_expert_type", None) @@ -630,6 +646,7 @@ def forward_cuda( global_num_experts=global_num_experts, zero_expert_num=zero_expert_num, zero_expert_type=zero_expert_type, + num_fused_shared_experts=layer.num_fused_shared_experts, ) if self.rocm_aiter_moe_enabled: @@ -701,21 +718,21 @@ def forward_cpu( top_k: int, router_logits: torch.Tensor, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if ( enable_eplb is not False or expert_load_view is not None @@ -750,21 +767,21 @@ def forward_xpu( top_k: int, router_logits: torch.Tensor, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if ( enable_eplb is not False or expert_load_view is not None @@ -791,21 +808,21 @@ def forward_tpu( top_k: int, router_logits: torch.Tensor, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + 
custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None @@ -856,7 +873,8 @@ def determine_expert_map( ep_rank: int, global_num_experts: int, expert_placement_strategy: ExpertPlacementStrategy = "linear", -) -> tuple[int, Optional[torch.Tensor]]: + num_fused_shared_experts: int = 0, +) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: """ Calculates how many experts should be assigned to each rank for EP and creates a mapping from global to local expert index. Experts are @@ -878,10 +896,16 @@ def determine_expert_map( (global_num_experts,) mapping from global to local index. Contains -1 for experts not assigned to the current rank. Returns None if ep_size is 1. + - expert_mask (Optional[torch.Tensor]): A tensor of shape + (global_num_experts + num_fused_shared_experts + 1,) + containing 1 for experts assigned to the current rank + and 0 for sentinel. + Returns None if ep_size is 1. + Used only when AITER MOE is enabled. """ assert ep_size > 0 if ep_size == 1: - return (global_num_experts, None) + return (global_num_experts, None, None) # Distribute experts as evenly as possible to each rank. 
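# The mapping computed below is easiest to see with a small worked example.
# This is an illustrative, standalone sketch of the linear placement that
# determine_expert_map performs, plus the AITER-style expert_mask layout; the
# helper name and the numbers are made up and are not part of the diff.
import torch

def sketch_linear_expert_map(ep_size: int, ep_rank: int, global_num_experts: int,
                             num_fused_shared_experts: int = 0):
    # Linear placement: rank r owns a contiguous slice of the global experts.
    base = global_num_experts // ep_size
    leftover = global_num_experts % ep_size
    local = base + (1 if ep_rank < leftover else 0)
    start = ep_rank * base + min(ep_rank, leftover)

    # Global -> local index, -1 for experts that live on other ranks.
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    expert_map[start:start + local] = torch.arange(local, dtype=torch.int32)

    # AITER-style mask: 1 for local routed experts and fused shared experts,
    # 0 for remote experts and for the trailing sentinel slot.
    expert_mask = torch.ones(global_num_experts + num_fused_shared_experts + 1,
                             dtype=torch.int32)
    expert_mask[-1] = 0
    expert_mask[:global_num_experts] = (expert_map > -1).to(torch.int32)
    return local, expert_map, expert_mask

# e.g. 8 routed experts, EP=2, rank 1, one fused shared expert:
# expert_map  -> [-1, -1, -1, -1, 0, 1, 2, 3]
# expert_mask -> [ 0,  0,  0,  0, 1, 1, 1, 1, 1, 0]
local, emap, emask = sketch_linear_expert_map(2, 1, 8, 1)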
base_experts = global_num_experts // ep_size @@ -910,7 +934,26 @@ def determine_expert_map( f"'{expert_placement_strategy}', expected one of " f"{get_args(ExpertPlacementStrategy)}" ) - return (local_num_experts, expert_map) + + expert_mask = None + if is_rocm_aiter_moe_enabled(): + expert_mask = torch.ones( + (global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32 + ) + expert_mask[-1] = 0 + expert_mask[:global_num_experts] = expert_map > -1 + expert_map = torch.cat( + ( + expert_map, + torch.tensor( + [local_num_experts + i for i in range(num_fused_shared_experts)], + dtype=torch.int32, + ), + ), + dim=0, + ) + + return (local_num_experts, expert_map, expert_mask) def get_compressed_expert_map(expert_map: torch.Tensor) -> str: @@ -937,7 +980,7 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str: def maybe_roundup_hidden_size( hidden_size: int, act_dtype: torch.dtype, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, moe_parallel_config: FusedMoEParallelConfig, ) -> int: """ @@ -960,6 +1003,11 @@ def maybe_roundup_hidden_size( hidden_size, act_dtype ) + if moe_parallel_config.use_deepep_ll_kernels: + hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size + ) + # we are padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": from vllm.model_executor.layers.quantization.mxfp4 import ( @@ -1012,37 +1060,51 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, + params_dtype: torch.dtype | None = None, reduce_results: bool = False, renormalize: bool = True, use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - ep_size: Optional[int] = None, - dp_size: Optional[int] = None, + num_expert_group: int | None = None, + topk_group: int | None = None, + quant_config: QuantizationConfig | None = None, + tp_size: int | None = None, + ep_size: int | None = None, + dp_size: int | None = None, prefix: str = "", - custom_routing_function: Optional[Callable] = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + is_act_and_mul: bool = True, enable_eplb: bool = False, num_redundant_experts: int = 0, has_bias: bool = False, is_sequence_parallel=False, - zero_expert_num: Optional[int] = 0, - zero_expert_type: Optional[str] = None, - expert_mapping: Optional[list[tuple[str, str, int, str]]] = None, + zero_expert_num: int | None = 0, + zero_expert_type: str | None = None, + expert_mapping: list[tuple[str, str, int, str]] | None = None, + n_shared_experts: int | None = None, ): super().__init__() + + # Allow disabling of the separate shared experts stream for + # debug purposes. 
+ # TODO: Remove this after more extensive testings with TP/DP + # and other execution modes + if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM: + logger.info_once("Disabling MoE shared_experts cuda stream") + self.shared_experts_stream = None + else: + self.shared_experts_stream = torch.cuda.Stream() + if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config # FIXME (varun): We should have a better way of inferring the activation # datatype. This works for now as the tensor datatype entering the MoE @@ -1088,9 +1150,25 @@ def __init__( self.layer_name = prefix self.enable_eplb = enable_eplb - self.expert_load_view: Optional[torch.Tensor] = None - self.logical_to_physical_map: Optional[torch.Tensor] = None - self.logical_replica_count: Optional[torch.Tensor] = None + self.expert_load_view: torch.Tensor | None = None + self.logical_to_physical_map: torch.Tensor | None = None + self.logical_replica_count: torch.Tensor | None = None + + # ROCm aiter shared experts fusion + self.num_fused_shared_experts = ( + n_shared_experts + if n_shared_experts is not None + and is_rocm_aiter_fusion_shared_expert_enabled() + else 0 + ) + if ( + not is_rocm_aiter_fusion_shared_expert_enabled() + and self.num_fused_shared_experts != 0 + ): + raise ValueError( + "n_shared_experts is only supported on ROCm aiter when " + "VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled" + ) # Determine expert maps if self.use_ep: @@ -1124,15 +1202,17 @@ def __init__( ) expert_placement_strategy = "linear" - self.expert_map: Optional[torch.Tensor] - local_num_experts, expert_map = determine_expert_map( + self.expert_map: torch.Tensor | None + local_num_experts, expert_map, expert_mask = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, expert_placement_strategy=expert_placement_strategy, + num_fused_shared_experts=self.num_fused_shared_experts, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) + self.register_buffer("expert_mask", expert_mask) logger.info_once( "[EP Rank %s/%s] Expert parallelism is enabled. Expert " "placement strategy: %s. Local/global" @@ -1146,10 +1226,18 @@ def __init__( get_compressed_expert_map(self.expert_map), ) else: - self.local_num_experts, self.expert_map = (self.global_num_experts, None) + self.local_num_experts, self.expert_map, self.expert_mask = ( + self.global_num_experts, + None, + None, + ) self.top_k = top_k + self._init_aiter_shared_experts_topK_buffer( + vllm_config=vllm_config, dp_size=dp_size_ + ) + assert intermediate_size % self.tp_size == 0 self.hidden_size = hidden_size self.intermediate_size_per_partition = intermediate_size // self.tp_size @@ -1181,14 +1269,15 @@ def __init__( in_dtype=moe_in_dtype, max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, has_bias=has_bias, + is_act_and_mul=is_act_and_mul, ) self.moe_config = moe - self.moe_quant_config: Optional[FusedMoEQuantConfig] = None + self.moe_quant_config: FusedMoEQuantConfig | None = None self.quant_config = quant_config # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. 
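# Reference sketch of what the `is_act_and_mul` flag changes downstream: with
# the fused gate_up layout, w13 produces 2 * intermediate_size columns that
# are split and gated ("silu_and_mul"); without it, w13 is a plain
# up-projection followed by the activation. Shapes and the helper name are
# illustrative only; this is not the fused CUDA path.
import torch
import torch.nn.functional as F

def w13_activation_reference(x: torch.Tensor, w13: torch.Tensor,
                             is_act_and_mul: bool) -> torch.Tensor:
    h = x @ w13.t()                      # (tokens, 2*I) or (tokens, I)
    if is_act_and_mul:
        gate, up = h.chunk(2, dim=-1)    # split the fused gate_up projection
        return F.silu(gate) * up         # silu_and_mul semantics
    return F.silu(h)                     # plain activation on the up-projection

x = torch.randn(4, 16)
out_gated = w13_activation_reference(x, torch.randn(2 * 32, 16), True)   # (4, 32)
out_plain = w13_activation_reference(x, torch.randn(32, 16), False)      # (4, 32)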
- quant_method: Optional[QuantizeMethodBase] = None + quant_method: QuantizeMethodBase | None = None quant_method = ( UnquantizedFusedMoEMethod(moe) if quant_config is None @@ -1201,6 +1290,24 @@ def __init__( assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method + if not self.moe_config.is_act_and_mul: + # Avoid circular import + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptFp8MoEMethod, + ) + + if not isinstance( + quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod) + ): + raise NotImplementedError( + "is_act_and_mul=False is supported only for unquantized " + "and ModelOpt FP8 moe for now" + ) + if not current_platform.is_cuda(): + raise NotImplementedError( + "is_act_and_mul=False is supported only for CUDA for now" + ) + if self.enable_eplb: from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod @@ -1222,6 +1329,7 @@ def __init__( "intermediate_size_per_partition": self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, + "global_num_experts": self.global_num_experts, } # need full intermediate size pre-sharding for WNA16 act order if self.quant_method.__class__.__name__ in ( @@ -1234,45 +1342,15 @@ def __init__( self.quant_method.create_weights(layer=self, **moe_quant_params) # Chunked all2all staging tensor - self.batched_hidden_states: Optional[torch.Tensor] = None - self.batched_router_logits: Optional[torch.Tensor] = None - - # TODO(bnell): flashinfer uses non-batched format. - # Does it really need a batched buffer? - if ( - self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or self.moe_config.use_flashinfer_cutlass_kernels - ): - if vllm_config.parallel_config.enable_dbo: - self.batched_hidden_states = torch.zeros( - (2, moe.max_num_tokens, self.hidden_size), - dtype=moe.in_dtype, - device=torch.cuda.current_device(), - ) + self.batched_hidden_states: torch.Tensor | None = None + self.batched_router_logits: torch.Tensor | None = None - # Note here we use `num_experts` which is logical expert count - self.batched_router_logits = torch.zeros( - (2, moe.max_num_tokens, num_experts), - dtype=moe.in_dtype, - device=torch.cuda.current_device(), - ) - else: - self.batched_hidden_states = torch.zeros( - (moe.max_num_tokens, self.hidden_size), - dtype=moe.in_dtype, - device=torch.cuda.current_device(), - ) - - # Note here we use `num_experts` which is logical expert count - self.batched_router_logits = torch.zeros( - (moe.max_num_tokens, num_experts), - dtype=moe.in_dtype, - device=torch.cuda.current_device(), - ) + @property + def shared_experts(self) -> torch.nn.Module | None: + return None @property - def shared_experts(self) -> Optional[torch.nn.Module]: + def gate(self) -> torch.nn.Module | None: return None @property @@ -1323,17 +1401,35 @@ def use_flashinfer_cutlass_kernels(self): and self.moe_config.use_flashinfer_cutlass_kernels ) + @property + def use_dp_chunking(self) -> bool: + return ( + self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels + or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) + ) + + @property + def is_internal_router(self) -> bool: + # By default, router/gate is called before FusedMoE forward pass + return False + def update_expert_map(self): # ep_size and ep_rank should already be updated assert self.expert_map is not None with self.expert_map.device: - local_num_experts, expert_map = determine_expert_map( + local_num_experts, 
expert_map, expert_mask = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts, + num_fused_shared_experts=self.num_fused_shared_experts, ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) + self.register_buffer("expert_mask", expert_mask) + self._init_aiter_shared_experts_topK_buffer( + vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size + ) def _load_per_tensor_weight_scale( self, @@ -1438,7 +1534,10 @@ def _load_w13( ): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = expert_data.shape[shard_dim] // 2 + if self.moe_config.is_act_and_mul: + shard_size = expert_data.shape[shard_dim] // 2 + else: + shard_size = expert_data.shape[shard_dim] if not load_full: loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size @@ -1504,6 +1603,24 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: return expert_id return self.expert_map[expert_id].item() + def _init_aiter_shared_experts_topK_buffer( + self, vllm_config: VllmConfig, dp_size: int + ): + if is_rocm_aiter_fusion_shared_expert_enabled(): + if self.num_fused_shared_experts > 0: + init_aiter_topK_meta_data( + n_routed_experts=self.global_num_experts, + n_shared_experts=self.num_fused_shared_experts, + top_k=self.top_k, + tp_rank=self.ep_rank if self.use_ep else self.tp_rank, + tp_size=self.ep_size if self.use_ep else self.tp_size, + shared_experts_score=1.0, + max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens + * dp_size, + is_EP=self.use_ep, + ) + self.local_num_experts += self.num_fused_shared_experts + @overload def weight_loader( self, @@ -1534,7 +1651,7 @@ def weight_loader( shard_id: str, expert_id: int, return_success: bool = False, - ) -> Optional[bool]: + ) -> bool | None: if self.quant_config and self.quant_config.get_name() == "mxfp4": # (FIXME) for gpt-oss all experts are combined if "bias" in weight_name: @@ -1546,13 +1663,25 @@ def weight_loader( param.data[:, :dim1, :dim2].copy_(loaded_weight) return True if return_success else None - expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) - if expert_id == -1: + quant_method_name = self.quant_method.__class__.__name__ + global_expert_id = expert_id + expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id) + + allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False) + moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None) + + use_global_sf = ( + allow_flashinfer + and is_flashinfer_supporting_global_sf(moe_backend) + and "input_scale" in weight_name + and quant_method_name == "ModelOptNvFp4FusedMoE" + ) + + if expert_id == -1 and not use_global_sf: # Failed to load this param since it's not local to this rank return False if return_success else None # Hereafter, `expert_id` is local physical id - quant_method_name = self.quant_method.__class__.__name__ # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -1637,7 +1766,9 @@ def weight_loader( ) self._load_single_value( - param=param, loaded_weight=loaded_weight, expert_id=expert_id + param=param, + loaded_weight=loaded_weight, + expert_id=global_expert_id if use_global_sf else expert_id, ) return True if return_success else None @@ -1838,12 +1969,40 
@@ def set_eplb_state( self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.logical_replica_count = logical_replica_count[moe_layer_idx] - def ensure_moe_quant_config(self): + def ensure_moe_quant_config_init(self): if self.quant_method.moe_quant_config is None: self.quant_method.moe_quant_config = ( self.quant_method.get_fused_moe_quant_config(self) ) + if self.moe_quant_config is None: + self.moe_quant_config = self.quant_method.moe_quant_config + + def ensure_dp_chunking_init(self): + if not self.use_dp_chunking or self.batched_hidden_states is not None: + return + + states_shape: tuple[int, ...] + logits_shape: tuple[int, ...] + + moe = self.moe_config + + # Note here we use `num_experts` which is logical expert count + if self.vllm_config.parallel_config.enable_dbo: + states_shape = (2, moe.max_num_tokens, self.hidden_size) + logits_shape = (2, moe.max_num_tokens, moe.num_experts) + else: + states_shape = (moe.max_num_tokens, self.hidden_size) + logits_shape = (moe.max_num_tokens, moe.num_experts) + + self.batched_hidden_states = torch.zeros( + states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device() + ) + + self.batched_router_logits = torch.zeros( + logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device() + ) + @staticmethod def select_experts( hidden_states: torch.Tensor, @@ -1851,21 +2010,22 @@ def select_experts( top_k: int, use_grouped_topk: bool, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - indices_type: Optional[torch.dtype] = None, + e_score_correction_bias: torch.Tensor | None = None, + indices_type: torch.dtype | None = None, enable_eplb: bool = False, - expert_map: Optional[torch.Tensor] = None, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - global_num_experts: Optional[int] = None, - zero_expert_num: Optional[int] = None, - zero_expert_type: Optional[str] = None, + expert_map: torch.Tensor | None = None, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + global_num_experts: int | None = None, + zero_expert_num: int | None = None, + zero_expert_type: str | None = None, + num_fused_shared_experts: int = 0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Route the input hidden states to the top-k experts based on the @@ -1900,7 +2060,16 @@ def select_experts( if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( + if is_rocm_aiter_moe_enabled(): + if not is_rocm_aiter_fusion_shared_expert_enabled(): + assert num_fused_shared_experts == 0 + grouped_topk_impl = partial( + grouped_topk_aiter, + num_fused_shared_experts=num_fused_shared_experts, + ) + else: + grouped_topk_impl = grouped_topk + topk_weights, topk_ids = grouped_topk_impl( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, @@ -1987,21 +2156,17 @@ def must_reduce_shared_expert_outputs(self) -> bool: Therefore it is required that we reduce the shared_experts output early. 
""" + assert self.quant_method is not None return ( - self.use_pplx_kernels - or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels + self.quant_method.fused_experts is not None + and self.quant_method.fused_experts.output_is_reduced() ) def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): """ - The pplx combine kernel reduces across GPU ranks by default. + Some combine kernels reduce across GPU ranks by default. """ - if ( - self.use_pplx_kernels - or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels - ): + if self.must_reduce_shared_expert_outputs(): return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -2010,7 +2175,7 @@ def forward_native( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: og_hidden_states = hidden_states.shape[-1] if self.hidden_size != og_hidden_states: hidden_states = F.pad( @@ -2051,14 +2216,15 @@ def forward_cuda( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: return self.forward_native(hidden_states, router_logits) def forward_impl_chunked( self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + has_separate_shared_experts: bool, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.batched_hidden_states is not None assert self.batched_router_logits is not None assert self.batched_hidden_states.dtype == full_hidden_states.dtype @@ -2067,8 +2233,6 @@ def forward_impl_chunked( assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1) assert self.batched_router_logits.size(-1) == full_router_logits.size(-1) - self.ensure_moe_quant_config() - full_fused_final_hidden_states = torch.empty_like(full_hidden_states) if self.shared_experts is not None: full_shared_final_hidden_states = torch.empty_like(full_hidden_states) @@ -2106,11 +2270,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): # If there are shared experts but we are not using a modular kernel, # the shared experts must be called here - if ( - not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) - and self.shared_experts is not None - ): - shared_output = self.shared_experts(staged_hidden_states) + if has_separate_shared_experts: + assert self.shared_experts is not None + + if self.shared_experts_stream is not None: + # For chunked, we start the shared experts stream here + # (Note that no concurrency with the router/gate) + self.shared_experts_stream.wait_stream(current_stream()) + + with torch.cuda.stream(self.shared_experts_stream): + # Note that staged_hidden_states clone() is necessary + # here to avoid conflict with the main stream + shared_output = self.shared_experts( + staged_hidden_states.clone() + ) + else: + shared_output = self.shared_experts(staged_hidden_states) + else: shared_output = None @@ -2123,7 +2299,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, - expert_map=self.expert_map, + expert_map=self.expert_map + if not is_rocm_aiter_moe_enabled() + else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, 
custom_routing_function=self.custom_routing_function, @@ -2137,9 +2315,14 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): logical_replica_count=self.logical_replica_count, ) - if shared_output is not None: + if has_separate_shared_experts: assert not isinstance(final_hidden_states, tuple) assert self.shared_experts is not None + + # Here we finish the shared experts stream + if self.shared_experts_stream is not None: + current_stream().wait_stream(self.shared_experts_stream) + final_hidden_states = ( shared_output, final_hidden_states, @@ -2204,37 +2387,57 @@ def forward_impl( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.quant_method is not None - self.ensure_moe_quant_config() + self.ensure_moe_quant_config_init() + self.ensure_dp_chunking_init() - # Route to the chunked forward path using the FlashInfer Cutlass kernel - # only when data parallelism (DP) is enabled. - _use_flashinfer_cutlass_kernels = ( - self.dp_size > 1 and self.use_flashinfer_cutlass_kernels + has_separate_shared_experts = ( + not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) + and self.shared_experts is not None ) + use_chunked_impl = self.use_dp_chunking + if ( - self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or _use_flashinfer_cutlass_kernels + has_separate_shared_experts + and not use_chunked_impl + and self.shared_experts_stream is not None ): - return self.forward_impl_chunked(hidden_states, router_logits) + # Start the separate shared experts stream here since we want + # to run in parallel with the router/gate (next op below) + self.shared_experts_stream.wait_stream(current_stream()) + + # If router/gate provided, then apply it here. 
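# The stream handshake used above, shown in isolation: fork a side stream, run
# the shared experts there while the routed path runs on the main stream, then
# join before the outputs are combined. Minimal standalone sketch with
# placeholder modules; assumes a CUDA device is available.
import torch

def overlapped_shared_experts(x, shared_experts, routed_path, side_stream):
    main = torch.cuda.current_stream()
    side_stream.wait_stream(main)            # fork: side stream sees prior work
    with torch.cuda.stream(side_stream):
        # clone() so the side stream does not race with reuse of x on main
        shared_out = shared_experts(x.clone())
    routed_out = routed_path(x)              # runs concurrently on the main stream
    main.wait_stream(side_stream)            # join before combining results
    return shared_out, routed_out

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    x = torch.randn(8, 16, device="cuda")
    mlp = torch.nn.Linear(16, 16).cuda()
    shared, routed = overlapped_shared_experts(x, mlp, mlp, stream)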
+ # (Note: This code runs only when "overlapped mode" is on to allow + # parallel execution of shared experts with the FusedMoE via + # separate cuda stream) + if self.gate is not None: + router_logits, _ = self.gate(hidden_states) + + if use_chunked_impl: + return self.forward_impl_chunked( + hidden_states, router_logits, has_separate_shared_experts + ) do_naive_dispatch_combine: bool = ( - self.dp_size > 1 - and not self.moe_parallel_config.use_deepep_ht_kernels - and not self.moe_config.use_flashinfer_cutlass_kernels + self.dp_size > 1 and not self.quant_method.using_modular_kernel ) # If there are shared experts but we are not using a modular kernel, the # shared experts must be called here - if ( - not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) - and self.shared_experts is not None - ): - shared_output = self.shared_experts(hidden_states) + if has_separate_shared_experts: + assert self.shared_experts is not None + + if self.shared_experts_stream is not None: + # Run shared experts in parallel on a separate stream + with torch.cuda.stream(self.shared_experts_stream): + # Note that hidden_states clone() is necessary here to avoid + # conflict with the main stream + shared_output = self.shared_experts(hidden_states.clone()) + else: + shared_output = self.shared_experts(hidden_states) else: shared_output = None @@ -2260,7 +2463,9 @@ def forward_impl( renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, global_num_experts=self.global_num_experts, - expert_map=self.expert_map, + expert_map=self.expert_map + if not is_rocm_aiter_moe_enabled() + else self.expert_mask, topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, @@ -2275,9 +2480,14 @@ def forward_impl( logical_replica_count=self.logical_replica_count, ) - if shared_output is not None: + if has_separate_shared_experts: assert not isinstance(final_hidden_states, tuple) assert self.shared_experts is not None + + # Wait for the parallel shared experts stream to finish here + if self.shared_experts_stream is not None: + current_stream().wait_stream(self.shared_experts_stream) + final_hidden_states = ( shared_output, final_hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 1f6209c9d08e..8514b63556ae 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass from enum import Enum from math import prod -from typing import Callable, Optional, Union, final +from typing import final import torch @@ -13,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, count_expert_num_tokens, + disable_inplace, ) from vllm.utils import cdiv from vllm.v1.worker.ubatching import ( @@ -80,7 +82,7 @@ class ExpertTokensMetadata: """ expert_num_tokens: torch.Tensor - expert_num_tokens_cpu: Optional[torch.Tensor] + expert_num_tokens_cpu: torch.Tensor | None @staticmethod def make_from_list( @@ -103,7 +105,7 @@ class TopKWeightAndReduce(ABC): @abstractmethod def apply( self, - output: Optional[torch.Tensor], + output: torch.Tensor | None, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, @@ -131,10 
+133,10 @@ def apply( # PrepareResultType = tuple[ torch.Tensor, - Optional[torch.Tensor], - Optional[ExpertTokensMetadata], - Optional[torch.Tensor], - Optional[torch.Tensor], + torch.Tensor | None, + ExpertTokensMetadata | None, + torch.Tensor | None, + torch.Tensor | None, ] ReceiverType = Callable[[], PrepareResultType] @@ -154,7 +156,7 @@ def prepare( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> PrepareResultType: @@ -194,10 +196,10 @@ def prepare_async( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> Union[tuple[Callable, ReceiverType], ReceiverType]: + ) -> tuple[Callable, ReceiverType] | ReceiverType: """ Perform any quantization (and/or) dispatching needed for this kernel but do not wait for results from other workers. @@ -269,7 +271,7 @@ def finalize_async( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: TopKWeightAndReduce, - ) -> Union[tuple[Callable, Callable], Callable]: + ) -> tuple[Callable, Callable] | Callable: """ Perform any combine plus apply weights and perform a reduction on the fused experts output but do not wait for results from other workers. @@ -313,7 +315,7 @@ def activation_format(self) -> FusedMoEActivationFormat: raise NotImplementedError @abstractmethod - def topk_indices_dtype(self) -> Optional[torch.dtype]: + def topk_indices_dtype(self) -> torch.dtype | None: """ The PrepareFinalize All2All implementations generally constrain the dtype of the topk_ids they support. This function returns the @@ -323,7 +325,7 @@ def topk_indices_dtype(self) -> Optional[torch.dtype]: raise NotImplementedError @abstractmethod - def max_num_tokens_per_rank(self) -> Optional[int]: + def max_num_tokens_per_rank(self) -> int | None: """ Some PrepareFinalize All2All implementations are batched. Meaning, they can process only as set of tokens at a time. This @@ -337,6 +339,14 @@ def max_num_tokens_per_rank(self) -> Optional[int]: def num_dispatchers(self) -> int: raise NotImplementedError + @abstractmethod + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of finalize is reduced across all + ranks. 
+ """ + raise NotImplementedError + # TODO: add supported activations method (return string) class FusedMoEPermuteExpertsUnpermute(ABC): @@ -414,11 +424,11 @@ def moe_problem_size( # @property - def quant_dtype(self) -> Optional[torch.dtype]: + def quant_dtype(self) -> torch.dtype | None: return self.quant_config.quant_dtype @property - def block_shape(self) -> Optional[list[int]]: + def block_shape(self) -> list[int] | None: return self.quant_config.block_shape @property @@ -430,51 +440,51 @@ def per_out_ch_quant(self) -> bool: return self.quant_config.per_out_ch_quant @property - def a1_scale(self) -> Optional[torch.Tensor]: + def a1_scale(self) -> torch.Tensor | None: return self.quant_config.a1_scale @property - def a2_scale(self) -> Optional[torch.Tensor]: + def a2_scale(self) -> torch.Tensor | None: return self.quant_config.a2_scale @property - def a1_gscale(self) -> Optional[torch.Tensor]: + def a1_gscale(self) -> torch.Tensor | None: return self.quant_config.a1_gscale @property - def a2_gscale(self) -> Optional[torch.Tensor]: + def a2_gscale(self) -> torch.Tensor | None: return self.quant_config.a2_gscale @property - def w1_scale(self) -> Optional[torch.Tensor]: + def w1_scale(self) -> torch.Tensor | None: return self.quant_config.w1_scale @property - def w2_scale(self) -> Optional[torch.Tensor]: + def w2_scale(self) -> torch.Tensor | None: return self.quant_config.w2_scale @property - def w1_zp(self) -> Optional[torch.Tensor]: + def w1_zp(self) -> torch.Tensor | None: return self.quant_config.w1_zp @property - def w2_zp(self) -> Optional[torch.Tensor]: + def w2_zp(self) -> torch.Tensor | None: return self.quant_config.w2_zp @property - def w1_bias(self) -> Optional[torch.Tensor]: + def w1_bias(self) -> torch.Tensor | None: return self.quant_config.w1_bias @property - def w2_bias(self) -> Optional[torch.Tensor]: + def w2_bias(self) -> torch.Tensor | None: return self.quant_config.w2_bias @property - def g1_alphas(self) -> Optional[torch.Tensor]: + def g1_alphas(self) -> torch.Tensor | None: return self.quant_config.g1_alphas @property - def g2_alphas(self) -> Optional[torch.Tensor]: + def g2_alphas(self) -> torch.Tensor | None: return self.quant_config.g2_alphas # TODO (bnell): make this return a CHUNK_SIZE or None instead? @@ -493,34 +503,49 @@ def supports_expert_map(self) -> bool: """ raise NotImplementedError + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + """ + Workspace type: The dtype to use for the workspace tensors. + """ + return act_dtype + @abstractmethod def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Compute the shapes for the temporary and final outputs of the two gemms and activation in the fused expert function. Since the gemms are independent, the workspace for the first gemm can be shared with the workspace for the last gemm. + Inputs: + - M: number of tokens. + - N: Row (or column) dimension of expert weights. + - K: hidden dimension + - topk: The number of top-k experts to select. + - global_num_experts: global number of experts. + - local_num_experts: local number of experts due to DP/EP. + - expert_tokens_meta: number of tokens per expert metadata for batched + format. 
+ Returns a tuple of: - workspace13 shape tuple: must be large enough to hold the result of either expert gemm. - workspace2 shape tuple: must be large enough to hold the result of the activation function. - output shape tuple: must be exact size of the final gemm output. - - Workspace type: The dtype to use for the workspace tensors. - - Note: in order for activation chunking to work, the first dimension - of each tuple must be the number of tokens. + - Note: workspace shapes can be 0 if the workspace is not needed. + But in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens when the shape is + not 0. """ raise NotImplementedError @@ -532,6 +557,9 @@ def activation( torch.ops._C.silu_and_mul(output, input) elif activation == "gelu": torch.ops._C.gelu_and_mul(output, input) + elif activation == "swigluoai": + # alpha = 1.702, limit = 7.0 + torch.ops._C.swigluoai_and_mul(output, input) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") @@ -554,14 +582,14 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[ExpertTokensMetadata], + expert_tokens_meta: ExpertTokensMetadata | None, apply_router_weight_on_input: bool, - ): + ) -> None: """ This function computes the intermediate result of a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2. @@ -600,9 +628,9 @@ def apply( raise NotImplementedError -def _chunk_scales( - scales: Optional[torch.Tensor], start: int, end: int -) -> Optional[torch.Tensor]: +def _slice_scales( + scales: torch.Tensor | None, start: int, end: int +) -> torch.Tensor | None: if scales is not None: if scales.numel() == 1: return scales @@ -615,9 +643,10 @@ class SharedResizableBuffer: def __init__(self): self.buffer = None - def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype): - if shape == () or shape is None: - return None + def get( + self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype + ) -> torch.Tensor: + assert shape != () shape_numel = prod(shape) if ( self.buffer is None @@ -663,7 +692,7 @@ def __init__( self, prepare_finalize: FusedMoEPrepareAndFinalize, fused_experts: FusedMoEPermuteExpertsUnpermute, - shared_experts: Optional[torch.nn.Module] = None, + shared_experts: torch.nn.Module | None = None, ): super().__init__() self.prepare_finalize = prepare_finalize @@ -678,131 +707,77 @@ def __init__( f"{fused_experts.activation_formats[0]}" ) - def _do_fused_experts( + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of fused MoE kernel + is reduced across all ranks. + """ + return self.prepare_finalize.output_is_reduced() + + def _chunk_info(self, M: int) -> tuple[int, int]: + """ + Compute number of chunks and chunk size for given M. + If chunking is not supported, set the CHUNK_SIZE to M so we + get num_chunks == 1. Take max(M, 1) to avoid divide by zero. + If there are no tokens to process, the number of chunks will be zero. + """ + CHUNK_SIZE = max( + 1, + ( + M + if not self.fused_experts.supports_chunking() + else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE) + ), + ) + num_chunks = cdiv(M, CHUNK_SIZE) + # If there are no tokens, then there should be no loop iterations. 
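# The chunking arithmetic used by _chunk_info, spelled out as a standalone
# sketch. The real code reads VLLM_FUSED_MOE_CHUNK_SIZE from the environment;
# 2048 below is just an example value.
def chunk_info(M: int, supports_chunking: bool, env_chunk_size: int = 2048):
    chunk_size = max(1, min(M, env_chunk_size) if supports_chunking else M)
    num_chunks = -(-M // chunk_size)  # ceiling division (cdiv)
    return num_chunks, chunk_size

assert chunk_info(5000, True) == (3, 2048)   # chunks [0,2048), [2048,4096), [4096,5000)
assert chunk_info(5000, False) == (1, 5000)  # no chunking support -> one pass
assert chunk_info(0, True) == (0, 1)         # no tokens -> zero loop iterations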
+ assert M > 0 or num_chunks == 0 + return num_chunks, CHUNK_SIZE + + def _allocate_buffers( self, - fused_out: Optional[torch.Tensor], - a1: torch.Tensor, - a1q: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, + out_dtype: torch.dtype, + device: torch.device, + M_chunk: int, + M_full: int, + N: int, + K: int, + top_k: int, global_num_experts: int, local_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ) -> torch.Tensor: - _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids) + expert_tokens_meta: ExpertTokensMetadata | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Allocate temporary and output buffers for the fused experts op. + Inputs: + - out_dtype: output type of workspace and output tensors. + - device: the device of the workspace and output tensors. + See `workspace_shapes` for a description of the remainder of arguments. + Returns a tuple of (workspace13, workspace2, output) tensors. + """ + assert M_full > 0 and M_chunk > 0 - (workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = ( - self.fused_experts.workspace_shapes( - a1, - a1q, - M, - N, - K, - top_k, - global_num_experts, - local_num_experts, - expert_tokens_meta, - ) - ) + num_chunks, _ = self._chunk_info(M_full) # select per-ubatch buffers to avoid cross-ubatch reuse under DBO ubatch_idx = dbo_current_ubatch_id() buffers = self.shared_buffers[ubatch_idx] + workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) - # We can reuse the memory between cache1 and cache3 because by the - # time we need cache3, we're done with cache1. - workspace13 = buffers.workspace13.get( - workspace13_shape, device=a1.device, dtype=workspace_dtype - ) - workspace2 = buffers.workspace2.get( - workspace2_shape, device=a1.device, dtype=workspace_dtype - ) - - assert fused_out is None or fused_out.shape == fused_out_shape, ( - f"fused_out {fused_out.shape} but expected {fused_out_shape}" - ) - if fused_out is None: - # reuse workspace13 for the output - fused_out = _resize_cache(workspace13, fused_out_shape) - - self.fused_experts.apply( - fused_out, - a1q, - w1, - w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, + # Get intermediate workspace shapes based off the chunked M size. 
+ workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes( + M_chunk, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, ) - return fused_out - - def _maybe_chunk_fused_experts( - self, - a1: torch.Tensor, - a1q: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - local_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ) -> torch.Tensor: - _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids) - - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - num_chunks = cdiv(M, CHUNK_SIZE) - - # TODO(bnell): get rid of one level here, update slice functions - # to nops on num_chunks==1 - - if not self.fused_experts.supports_chunking() or num_chunks == 1: - return self._do_fused_experts( - fused_out=None, - a1=a1, - a1q=a1q, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - global_num_experts=global_num_experts, - local_num_experts=local_num_experts, - expert_map=expert_map, - a1q_scale=a1q_scale, - a2_scale=self.fused_experts.a2_scale, - expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - # Chunking required case - assert num_chunks > 1 - - # Construct the entire output that can then be processed in chunks. - (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( - a1, - a1q, - M, + # Get final output shape based on the full M size. + _, _, fused_out_shape = self.fused_experts.workspace_shapes( + M_full, N, K, top_k, @@ -810,150 +785,99 @@ def _maybe_chunk_fused_experts( local_num_experts, expert_tokens_meta, ) - ubatch_idx = dbo_current_ubatch_id() - buffers = self.shared_buffers[ubatch_idx] - fused_out = buffers.fused_out.get( - fused_out_shape, device=a1q.device, dtype=a1.dtype + + # We can reuse the memory between cache1 and cache3 because by the + # time we need cache3, we're done with cache1. + workspace13 = buffers.workspace13.get( + workspace13_shape, device=device, dtype=workspace_dtype + ) + workspace2 = buffers.workspace2.get( + workspace2_shape, device=device, dtype=workspace_dtype ) - def slice_input_tensors( - chunk_idx: int, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[torch.Tensor], - torch.Tensor, - torch.Tensor, - ]: - s = chunk_idx * CHUNK_SIZE - e = min(s + CHUNK_SIZE, M) - return ( - a1q[s:e], - _chunk_scales(a1q_scale, s, e), - _chunk_scales(self.fused_experts.a2_scale, s, e), - topk_ids[s:e], - topk_weights[s:e], + # Construct the entire output that can then be processed in chunks. + # Reuse workspace13 for the output in the non-chunked case as long + # as it is large enough. This will not always be the case for standard + # format experts and with experts that have empty workspaces. 
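# Simplified stand-in for the reuse described above: a single flat buffer is
# viewed as whatever workspace or output shape is requested, so no extra
# allocation happens when the output fits inside workspace13. This is not the
# actual _resize_cache / SharedResizableBuffer implementation, only the idea.
import math
import torch

def view_as(buffer: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    numel = math.prod(shape)
    assert buffer.numel() >= numel, "buffer too small for requested shape"
    return buffer.flatten()[:numel].view(shape)

workspace13 = torch.empty(6 * 1024, dtype=torch.bfloat16)
gemm_out = view_as(workspace13, (64, 64))      # first gemm scratch
fused_out = view_as(workspace13, (32, 128))    # later reused for the output
assert gemm_out.data_ptr() == fused_out.data_ptr()  # same memory, no new alloc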
+ if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): + fused_out = _resize_cache(workspace13, fused_out_shape) + else: + fused_out = buffers.fused_out.get( + fused_out_shape, device=device, dtype=out_dtype ) - def slice_output_tensor(chunk_idx: int) -> torch.Tensor: - assert fused_out.size(0) % M == 0, ( - f"fused_out shape {fused_out.shape} vs M {M}" - ) - factor = fused_out.size(0) // M - out_chunk_size = CHUNK_SIZE * factor - s = chunk_idx * out_chunk_size - e = min(s + out_chunk_size, fused_out.size(0)) - return fused_out[s:e] - - def slice_expert_tokens_metadata( - full_expert_tokens_meta: ExpertTokensMetadata, - chunk_topk_ids: torch.Tensor, - local_num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> ExpertTokensMetadata: - # The existing expert_num_tokens is for the entire a1q - # input. Chunking forces recomputation of the number - # of tokens assigned to each expert. - c_expert_num_tokens = count_expert_num_tokens( - chunk_topk_ids, local_num_experts, expert_map - ) + return workspace13, workspace2, fused_out - c_expert_num_tokens_cpu = None - need_expert_num_tokens_cpu = ( - full_expert_tokens_meta.expert_num_tokens_cpu is not None - ) - if need_expert_num_tokens_cpu: - # This is blocking as some implementations need the count - # on the CPU to determine appropriate input/out fused-moe - # buffers - c_expert_num_tokens_cpu = c_expert_num_tokens.to( - "cpu", non_blocking=False - ) - - return ExpertTokensMetadata( - expert_num_tokens=c_expert_num_tokens, - expert_num_tokens_cpu=c_expert_num_tokens_cpu, - ) + @staticmethod + def _slice_output_tensor( + fused_out: torch.Tensor, + chunk_idx: int, + num_chunks: int, + CHUNK_SIZE: int, + M: int, + ) -> torch.Tensor: + if num_chunks == 1: + return fused_out - for chunk_idx in range(num_chunks): - c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( - slice_input_tensors(chunk_idx) - ) + assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}" + factor = fused_out.size(0) // M + out_chunk_size = CHUNK_SIZE * factor + s = chunk_idx * out_chunk_size + e = min(s + out_chunk_size, fused_out.size(0)) + return fused_out[s:e] - c_expert_tokens_meta = None - if expert_tokens_meta is not None: - c_expert_tokens_meta = slice_expert_tokens_metadata( - expert_tokens_meta, c_topk_ids, local_num_experts, expert_map - ) + @staticmethod + def _slice_expert_tokens_metadata( + num_chunks: int, + full_expert_tokens_meta: ExpertTokensMetadata | None, + chunk_topk_ids: torch.Tensor, + local_num_experts: int, + expert_map: torch.Tensor | None, + ) -> ExpertTokensMetadata | None: + if num_chunks == 1 or full_expert_tokens_meta is None: + return full_expert_tokens_meta + + # The existing expert_num_tokens is for the entire a1q + # input. Chunking forces recomputation of the number + # of tokens assigned to each expert. 
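# A readable reference for the per-chunk recount described above (the real
# helper is count_expert_num_tokens); illustrative only, details may differ.
import torch

def chunk_expert_num_tokens(chunk_topk_ids: torch.Tensor,
                            local_num_experts: int,
                            expert_map: torch.Tensor | None) -> torch.Tensor:
    ids = chunk_topk_ids.flatten().long()
    if expert_map is not None:
        ids = expert_map.long()[ids]           # global -> local, -1 if not local
    ids = ids[ids >= 0]                        # drop tokens routed to other ranks
    return torch.bincount(ids, minlength=local_num_experts)

# 2 local experts (global ids 2 and 3 on this rank), chunk of 3 tokens, top-2:
expert_map = torch.tensor([-1, -1, 0, 1], dtype=torch.int32)
topk_ids = torch.tensor([[2, 0], [3, 2], [1, 3]])
counts = chunk_expert_num_tokens(topk_ids, 2, expert_map)  # -> tensor([2, 2])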
+ c_expert_num_tokens = count_expert_num_tokens( + chunk_topk_ids, local_num_experts, expert_map + ) - self._do_fused_experts( - fused_out=slice_output_tensor(chunk_idx), - a1=a1, - a1q=c_a1q, - w1=w1, - w2=w2, - topk_weights=c_topk_weights, - topk_ids=c_topk_ids, - activation=activation, - global_num_experts=global_num_experts, - local_num_experts=local_num_experts, - expert_map=expert_map, - a1q_scale=c_a1q_scale, - a2_scale=c_a2_scale, - expert_tokens_meta=c_expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, - ) + c_expert_num_tokens_cpu = None + need_expert_num_tokens_cpu = ( + full_expert_tokens_meta.expert_num_tokens_cpu is not None + ) + if need_expert_num_tokens_cpu: + # This is blocking as some implementations need the count + # on the CPU to determine appropriate input/out fused-moe + # buffers + c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False) - return fused_out + return ExpertTokensMetadata( + expert_num_tokens=c_expert_num_tokens, + expert_num_tokens_cpu=c_expert_num_tokens_cpu, + ) - def forward( + def _prepare( self, hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + global_num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + ExpertTokensMetadata | None, + torch.Tensor, + torch.Tensor, + ]: """ - This function computes a Mixture of Experts (MoE) layer using two sets - of weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states: (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - topk_weights (torch.Tensor): The topk weights applied at the end of - the layer. - - topk_ids (torch.Tensor): A map of row to expert id. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - activation (str): The activation function to apply after the first - MoE layer. - - global_num_experts (int): The total number of experts in the global - expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - apply_router_weight_on_input (bool): When true, the topk weights are - applied directly on the inputs. This is only applicable when topk is - 1. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. + The _prepare method is a wrapper around self.prepare_finalize.prepare + that handles DBO and async. """ - - a1 = hidden_states - output = a1 if inplace and self.shared_experts is None else torch.zeros_like(a1) - - local_num_experts = w1.size(0) - if global_num_experts == -1: - global_num_experts = local_num_experts - if not self.prepare_finalize.supports_async(): # We shouldn't be running an a2a kernel that doesn't # support async prepare/finalize @@ -967,7 +891,7 @@ def forward( _expert_topk_ids, _expert_topk_weights, ) = self.prepare_finalize.prepare( - a1, + hidden_states, topk_weights, topk_ids, global_num_experts, @@ -979,7 +903,7 @@ def forward( # Overlap shared expert compute with all2all dispatch. 
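# Schematic of the async-prepare overlap referenced above, with placeholder
# callables: kick off the all2all dispatch, do independent work (the shared
# experts) while it is in flight, then call the receiver to obtain the
# dispatched tensors. Not tied to any particular all2all backend; the real
# prepare_async may also return a hook/receiver pair.
from collections.abc import Callable

def prepare_with_overlap(prepare_async: Callable[[], Callable],
                         run_shared_experts: Callable[[], object]):
    receiver = prepare_async()         # launches dispatch, returns a receiver
    shared_out = run_shared_experts()  # overlaps with the in-flight all2all
    prepared = receiver()              # blocks until dispatch results are ready
    return prepared, shared_out

# usage with trivial stand-ins:
prepared, shared = prepare_with_overlap(lambda: (lambda: "dispatched tokens"),
                                        lambda: "shared expert output")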
dbo_maybe_run_recv_hook() prepare_ret = self.prepare_finalize.prepare_async( - a1, + hidden_states, topk_weights, topk_ids, global_num_experts, @@ -1019,34 +943,115 @@ def forward( topk_weights if _expert_topk_weights is None else _expert_topk_weights ) - fused_out = None + return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights - if a1q.numel() == 0: - # This happens when none of the tokens from the all2all reach this - # EP rank. Also, note that this is only relevant for CUDAGraph - # incompatible all2all kernels like the DeepEP high-throughput - # kernels. CUDAGraph compatible all2all kernels like the pplx - # kernels and the DeepEP low-latency kernels are always batched - # and can never run into the tensor.numel() == 0 case. - fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) + def _fused_experts( + self, + in_dtype: torch.dtype, + a1q: torch.Tensor, + a1q_scale: torch.Tensor | None, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + expert_tokens_meta: ExpertTokensMetadata | None, + ) -> torch.Tensor: + _, M_full, N, K, top_k = self.fused_experts.moe_problem_size( + a1q, w1, w2, topk_ids + ) + + num_chunks, CHUNK_SIZE = self._chunk_info(M_full) + + def input_chunk_range(chunk_idx: int) -> tuple[int, int]: + if num_chunks == 1: + # Use a1q.size(0) here since batched format does not + # keep M in the first dimension. + return 0, a1q.size(0) + else: + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M_full) + return s, e + + # This happens when none of the tokens from the all2all reach this + # EP rank. Also, note that this is only relevant for CUDAGraph + # incompatible all2all kernels like the DeepEP high-throughput + # kernels. CUDAGraph compatible all2all kernels like the pplx + # kernels and the DeepEP low-latency kernels are always batched + # and can never run into the tensor.numel() == 0 case. 
+ if M_full == 0: + assert num_chunks == 0 + workspace13 = None + workspace2 = None + fused_out = torch.empty_like(a1q, dtype=in_dtype) else: - fused_out = self._maybe_chunk_fused_experts( - a1=a1, - a1q=a1q, + assert num_chunks > 0 + workspace13, workspace2, fused_out = self._allocate_buffers( + in_dtype, + a1q.device, + CHUNK_SIZE, + M_full, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) + + for chunk_idx in range(num_chunks): + s, e = input_chunk_range(chunk_idx) + + c_expert_tokens_meta = self._slice_expert_tokens_metadata( + num_chunks, + expert_tokens_meta, + topk_ids[s:e], + local_num_experts, + expert_map, + ) + + c_fused_out = self._slice_output_tensor( + fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full + ) + + self.fused_experts.apply( + output=c_fused_out, + hidden_states=a1q[s:e], w1=w1, w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_weights=topk_weights[s:e], + topk_ids=topk_ids[s:e], activation=activation, global_num_experts=global_num_experts, - local_num_experts=local_num_experts, expert_map=expert_map, - a1q_scale=a1q_scale, - expert_tokens_meta=expert_tokens_meta, + a1q_scale=_slice_scales(a1q_scale, s, e), + a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e), + workspace13=workspace13, + workspace2=workspace2, + expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) - shared_output: Optional[torch.Tensor] = None + return fused_out + + def _finalize( + self, + output: torch.Tensor, + fused_out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + The _finalize method is a wrapper around self.prepare_finalize.finalize + that handles DBO, async and shared expert overlap. + """ + shared_output: torch.Tensor | None = None if not self.prepare_finalize.supports_async(): assert not dbo_enabled() @@ -1060,7 +1065,7 @@ def forward( self.fused_experts.finalize_weight_and_reduce_impl(), ) if self.shared_experts is not None: - shared_output = self.shared_experts(a1) + shared_output = self.shared_experts(hidden_states) else: finalize_ret = self.prepare_finalize.finalize_async( output, @@ -1072,7 +1077,7 @@ def forward( ) if self.shared_experts is not None: - shared_output = self.shared_experts(a1) + shared_output = self.shared_experts(hidden_states) # TODO(lucas): refactor this in the alternative schedules followup # currently unpack if we have hook + receiver pair or just @@ -1100,3 +1105,87 @@ def forward( else: assert shared_output is not None return shared_output, output + + def forward( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + This function computes a Mixture of Experts (MoE) layer using two sets + of weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states: (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The topk weights applied at the end of + the layer. + - topk_ids (torch.Tensor): A map of row to expert id.
+ - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is + 1. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + + if inplace and self.shared_experts is None and not disable_inplace(): + output = hidden_states + else: + output = torch.zeros_like(hidden_states) + + local_num_experts = w1.size(0) + if global_num_experts == -1: + global_num_experts = local_num_experts + + a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = self._prepare( + hidden_states, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + ) + + fused_out = self._fused_experts( + in_dtype=hidden_states.dtype, + a1q=a1q, + a1q_scale=a1q_scale, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_tokens_meta=expert_tokens_meta, + ) + + return self._finalize( + output, + fused_out, + hidden_states, + topk_weights, + topk_ids, + apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index 9994088ca5d9..f4d8a86c058a 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -13,7 +12,7 @@ def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, pad_sorted_ids: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ @@ -84,3 +83,92 @@ def moe_align_block_size( expert_ids = expert_map[expert_ids] return sorted_ids, expert_ids, num_tokens_post_pad + + +def batched_moe_align_block_size( + max_tokens_per_batch: int, block_size: int, expert_num_tokens: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given num_batches, max_tokens_per_batch, block_size and the number of + valid tokens in each batch, prepare sorted_token_ids, expert_ids and + num_tokens_post_pad. sorted_token_ids, expert_ids and num_tokens_post_pad + have the same semantics as in moe_align_block_size. + + This function is intended to be a drop-in replacement for + moe_align_block_size for the batched case. + + Parameters: + - max_tokens_per_batch (int): Number of tokens in each batch (both + valid and invalid). + - block_size (int): block_size to align the data to. + - expert_num_tokens (torch.Tensor): expert_num_tokens[i] indicates + the number of valid tokens in batch i. + + Returns: + - sorted_token_ids (torch.Tensor): Torch tensor of size + (num_batches * max_tokens_per_batch) indicating the token indices for + that block.
+ - expert_ids (torch.Tensor): Torch tensor of size + ceil((num_batches * max_tokens_per_batch) / block_size) indicating + what expert to use for each block. + - num_tokens_post_pad (torch.Tensor): Torch tensor of size 1 + indicating the number of valid blocks with actual data to + process. This is represented in terms of num tokens. + Example: + Let num_batches=5, max_tokens_per_batch=8, block_size=4, and + expert_num_tokens=[2, 3, 0, 6, 8]. This expert_num_tokens tensor + indicates that, + - The first 2 tokens in the 0th batch are valid and the rest 6 are + invalid (i.e. in the 2D hidden_states tensor of shape, + [num_batches * max_tokens_per_batch, K], indices 0, 1 are valid) + - The first 3 tokens in the 1st batch are valid. i.e. indices 8, 9, 10 + - 0 tokens in the 2nd batch are valid + - first 6 tokens in the 3rd batch are valid. i.e. indices, + 24, 25, 26, 27, 28, 29 + - so on ... + + In this case, + sorted_token_ids will be [0, 1, 40, 40, + 8, 9, 10, 40, + 24, 25, 26, 27, + 28, 29, 40, 40, + 32, 33, 34, 35, + 36, 37, 38, 39, + 40, 40, 40, 40, + (rest all 40, 40, 40, 40) + ...] + Here, 40 represents an invalid index. as there is no token index 40. + The gemm kernel using this sorted_token_ids is expected to skip the + gemm computation when it encounters this invalid index. + + expert_ids will be [0, 1, 3, 3, 4, 5, 5, -1, -1, (rest all -1) ...] + Here, -1 represents an invalid expert. The gemm kernel using this + expert_ids is expected to skip the gemm computation when it encounters + an expert of id -1. + + num_tokens_post_pad will be 24 as sorted_token_ids has valid entries + until 24. + """ + + B = expert_num_tokens.size(0) + device = expert_num_tokens.device + + # Round up so each batch can be split to blocks evenly. + max_num_tokens_padded = B * round_up(max_tokens_per_batch, block_size) + + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device) + assert max_num_tokens_padded % block_size == 0 + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=device) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device) + + ops.batched_moe_align_block_size( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + + return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 698080f8aec6..9dcdcc380036 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -13,14 +12,12 @@ def _moe_permute( curr_hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], + a1q_scale: torch.Tensor | None, curr_topk_ids: torch.Tensor, global_num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, block_m: int, -) -> tuple[ - torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor -]: +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]: """ Determine the sorted_token_ids, expert_ids for the given problem size. Permute the hidden states and scales according to `sorted_token_ids`. 
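To make the new helper concrete, here is a minimal usage sketch (not part of the diff) that feeds batched_moe_align_block_size, introduced above in moe_align_block_size.py, the values from its docstring example. It assumes a CUDA build of vLLM with the ops.batched_moe_align_block_size custom op available, and an int32 expert_num_tokens tensor on the GPU; those assumptions are not spelled out in the diff itself.

import torch

from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    batched_moe_align_block_size,
)

# 5 batches with up to 8 token slots each; expert_num_tokens[i] is the number
# of valid tokens in batch i (the docstring example above).
expert_num_tokens = torch.tensor([2, 3, 0, 6, 8], dtype=torch.int32, device="cuda")

sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size(
    max_tokens_per_batch=8,
    block_size=4,
    expert_num_tokens=expert_num_tokens,
)

# sorted_ids has 5 * 8 = 40 slots, with 40 used as the invalid-index filler;
# expert_ids has 40 / 4 = 10 block entries, with -1 marking invalid blocks;
# num_tokens_post_pad should report 24, matching the worked example above.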
@@ -33,7 +30,7 @@ def _moe_permute( curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True ) - inv_perm: Optional[torch.Tensor] = None + inv_perm: torch.Tensor | None = None num_tokens = top_k_num * tokens_in_chunk expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) @@ -53,7 +50,7 @@ def _moe_permute( def _moe_unpermute_and_reduce( out: torch.Tensor, curr_hidden: torch.Tensor, - inv_perm: Optional[torch.Tensor], + inv_perm: torch.Tensor | None, topk_weight: torch.Tensor, apply_router_weight_on_input: bool, ) -> None: @@ -73,17 +70,15 @@ def _moe_unpermute_and_reduce( def moe_permute( hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], + a1q_scale: torch.Tensor | None, topk_ids: torch.Tensor, n_expert: int, n_local_expert: int = -1, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, + expert_map: torch.Tensor | None = None, + align_block_size: int | None = None, fill_invalid_expert: int = -1, - permuted_hidden_states: Optional[torch.Tensor] = None, -) -> tuple[ - torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor -]: + permuted_hidden_states: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]: """ This function expands and permutes activation to gather uncontinuous tokens for each expert. @@ -198,7 +193,7 @@ def moe_unpermute( permuted_hidden_states: torch.Tensor, topk_weights: torch.Tensor, inv_permuted_idx: torch.Tensor, - expert_first_token_offset: Optional[torch.Tensor] = None, + expert_first_token_offset: torch.Tensor | None = None, ) -> None: """ This function expands and permutes activation to gathering uncontinuous diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 79212c2b689d..0e77fa54cd50 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional, Union +from collections.abc import Callable import pplx_kernels as pplx import torch @@ -24,9 +24,9 @@ def pplx_hidden_dim_scale_bytes( max_num_tokens: int, hidden_dim: int, in_dtype: torch.dtype, - quant_dtype: Union[torch.dtype, str, None], + quant_dtype: torch.dtype | str | None, per_act_token_quant: bool, - block_shape: Optional[list[int]], + block_shape: list[int] | None, ): # All pplx byte sizes must be 16-byte aligned. 
align = 16 @@ -82,15 +82,18 @@ def __init__( def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.BatchedExperts - def max_num_tokens_per_rank(self) -> Optional[int]: + def max_num_tokens_per_rank(self) -> int | None: return self.max_num_tokens - def topk_indices_dtype(self) -> Optional[torch.dtype]: + def topk_indices_dtype(self) -> torch.dtype | None: return torch.uint32 def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + def supports_async(self) -> bool: return True @@ -100,7 +103,7 @@ def prepare_async( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> tuple[Callable, mk.ReceiverType]: @@ -145,7 +148,7 @@ def prepare_async( a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape ) - orig_a_scale_block_shape: Optional[int] = None + orig_a_scale_block_shape: int | None = None if a1q_scale is not None: scalar_scales = a1q_scale.numel() == 1 @@ -181,7 +184,7 @@ def prepare_async( device=device, ) - expert_x_scale: Optional[torch.Tensor] = None + expert_x_scale: torch.Tensor | None = None if a1q.dtype.itemsize == 1: if quant_config.is_per_act_token: # (M x 1) -> (E x M x K) @@ -209,7 +212,7 @@ def prepare_async( # This argument is optional, defaults to indices.size(0) # There's not much point setting this unless it is != indices.size(0) - bound_m: Optional[torch.Tensor] = None + bound_m: torch.Tensor | None = None self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -249,8 +252,8 @@ def _receiver( self, expert_num_tokens: torch.Tensor, expert_x: torch.Tensor, - expert_x_scale: Optional[torch.Tensor], - orig_a_scale_block_shape: Optional[int], + expert_x_scale: torch.Tensor | None, + orig_a_scale_block_shape: int | None, ) -> mk.PrepareResultType: if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] @@ -268,7 +271,7 @@ def prepare( topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: @@ -299,7 +302,7 @@ def finalize_async( # This argument is optional # There's not much point setting this unless it is != topk_ids.size(0) - bound_m: Optional[torch.Tensor] = None + bound_m: torch.Tensor | None = None # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on # num_tokens = output.size(0) # M diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index be6939a3f62f..9bb976fb9ec9 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -18,22 +17,25 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - def max_num_tokens_per_rank(self) -> Optional[int]: + def max_num_tokens_per_rank(self) -> int | None: return None - def topk_indices_dtype(self) -> Optional[torch.dtype]: + def topk_indices_dtype(self) -> torch.dtype | None: return 
None def num_dispatchers(self) -> int: return 1 + def output_is_reduced(self) -> bool: + return False + def prepare( self, a1: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, - expert_map: Optional[torch.Tensor], + expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 801785b18fb9..e18514ad43f6 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import IntEnum -from functools import cache -from typing import Optional +from functools import cache, lru_cache import torch @@ -12,7 +11,7 @@ FusedMoEQuantConfig, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class QuantMethod(IntEnum): @@ -47,19 +46,87 @@ def is_rocm_aiter_moe_enabled() -> bool: ) +@cache +def use_mxfp4_aiter_moe() -> bool: + return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER + + +@cache +def is_rocm_aiter_fusion_shared_expert_enabled() -> bool: + return ( + envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled() + ) + + +aiter_topK_meta_data = None + + +@lru_cache(maxsize=1) +def init_aiter_topK_meta_data( + n_routed_experts: int, + n_shared_experts: int, + top_k: int, + tp_rank: int, + tp_size: int, + shared_experts_score: float = 1.0, + max_num_tokens: int = 32768, + is_EP: bool = False, +): + global aiter_topK_meta_data + fake_expertid = n_routed_experts + n_shared_experts + + # all layers reuse same buffer + # This extra element when EP is enabled is used as a sentinel + # to mask out shared expert processing for tokens not owned by + # the current EP rank. This is necessary to avoid double-processing + # of shared experts. 
+ total_topk_ids = torch.empty( + (max_num_tokens, top_k + n_shared_experts + is_EP), + dtype=torch.int32, + device="cuda", + ) + ns_topk_ids, s_topk_ids = total_topk_ids.split( + [top_k, n_shared_experts + is_EP], dim=1 + ) + shared_expert_ids = [n_routed_experts + i for i in range(n_shared_experts + is_EP)] + if is_EP: + s_topk_ids_list = [ + [fake_expertid] * (n_shared_experts + is_EP) + ] * max_num_tokens + for i in range(tp_rank, max_num_tokens, tp_size): + s_topk_ids_list[i] = shared_expert_ids + else: + s_topk_ids_list = [ + list(range(n_routed_experts, fake_expertid)) + ] * max_num_tokens + s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device="cuda") + + total_topk_weights = torch.empty( + (max_num_tokens, top_k + n_shared_experts + is_EP), + dtype=torch.float32, + device="cuda", + ) + ns_topk_weights, s_topk_weights = total_topk_weights.split( + [top_k, n_shared_experts + is_EP], dim=1 + ) + s_topk_weights.fill_(shared_experts_score) + assert aiter_topK_meta_data is None, "AITER topK meta data is already initialized" + aiter_topK_meta_data = (total_topk_weights, total_topk_ids) + + def rocm_aiter_asm_moe_tkw1_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, a16: bool = False, - per_tensor_quant_scale: Optional[torch.Tensor] = None, - expert_mask: Optional[torch.Tensor] = None, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, activation_method: int = ActivationMethod.SILU.value, ) -> torch.Tensor: from aiter import ActivationType @@ -90,13 +157,13 @@ def rocm_aiter_asm_moe_tkw1_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, + fc1_scale: torch.Tensor | None = None, + fc2_scale: torch.Tensor | None = None, + fc1_smooth_scale: torch.Tensor | None = None, + fc2_smooth_scale: torch.Tensor | None = None, a16: bool = False, - per_tensor_quant_scale: Optional[torch.Tensor] = None, - expert_mask: Optional[torch.Tensor] = None, + per_tensor_quant_scale: torch.Tensor | None = None, + expert_mask: torch.Tensor | None = None, activation_method: int = ActivationMethod.SILU.value, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -206,14 +273,14 @@ def rocm_aiter_fused_moe_impl( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - expert_mask: Optional[torch.Tensor] = None, + expert_mask: torch.Tensor | None = None, activation_method: int = ActivationMethod.SILU.value, quant_method: int = QuantMethod.NO.value, doweight_stage1: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, ) -> torch.Tensor: from aiter import ActivationType, QuantType from aiter.fused_moe 
import fused_moe @@ -244,14 +311,14 @@ def rocm_aiter_fused_moe_fake( w2: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor, - expert_mask: Optional[torch.Tensor] = None, + expert_mask: torch.Tensor | None = None, activation_method: int = ActivationMethod.SILU.value, quant_method: int = QuantMethod.NO.value, doweight_stage1: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -300,12 +367,34 @@ def rocm_aiter_grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, + num_fused_shared_experts: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] device = hidden_states.device - topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) + if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + assert aiter_topK_meta_data is not None, ( + "AITER topK meta data is not initialized. " + "Please ensure that init_aiter_topK_meta_data " + "is called before this function." + ) + total_topk_weights, total_topk_ids = aiter_topK_meta_data + assert total_topk_weights.shape[0] >= token, ( + f"AITER topK meta data support {total_topk_weights.shape[0]} " + f"tokens which is determined by max_num_batched_tokens, " + f"but got {token} tokens now." 
+ ) + total_topk_weights = total_topk_weights[:token] + total_topk_ids = total_topk_ids[:token] + topk_weights, _ = total_topk_weights.split( + [topk, total_topk_weights.shape[1] - topk], dim=1 + ) + topk_ids, _ = total_topk_ids.split( + [topk, total_topk_ids.shape[1] - topk], dim=1 + ) + else: + topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) if e_score_correction_bias is not None: torch.ops.vllm.rocm_aiter_biased_grouped_topk( @@ -316,6 +405,7 @@ def rocm_aiter_grouped_topk( num_expert_group, topk_group, renormalize, + routed_scaling_factor=routed_scaling_factor, ) else: assert scoring_func == "softmax" or scoring_func == "sigmoid" @@ -327,10 +417,11 @@ def rocm_aiter_grouped_topk( topk_group, renormalize, scoring_func, + routed_scaling_factor=routed_scaling_factor, ) - if routed_scaling_factor != 1.0: - topk_weights = topk_weights * routed_scaling_factor + if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0: + return total_topk_weights, total_topk_ids return topk_weights, topk_ids @@ -342,8 +433,8 @@ def rocm_aiter_fused_experts( topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, - expert_map: Optional[torch.Tensor] = None, - quant_config: Optional[FusedMoEQuantConfig] = None, + expert_map: torch.Tensor | None = None, + quant_config: FusedMoEQuantConfig | None = None, ) -> torch.Tensor: if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG @@ -355,7 +446,7 @@ def rocm_aiter_fused_experts( topk_weights = topk_weights.to(torch.float32) topk_ids = topk_ids.to(torch.int32) - expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None + expert_mask = expert_map if expert_map is not None else None # w8a8 per-channel quantization if ( @@ -401,6 +492,8 @@ def rocm_aiter_fused_experts( assert quant_config.w1_scale is not None assert quant_config.w2_scale is not None quant_method = QuantMethod.BLOCK_128x128.value + elif quant_config.use_fp8_w8a8 and quant_config.per_out_ch_quant: + quant_method = QuantMethod.PER_TOKEN.value elif quant_config.use_fp8_w8a8: # Currently only per tensor quantization method is enabled. quant_method = QuantMethod.PER_TENSOR.value diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/routing_simulator.py index af20f4b7c1d2..8b04cf4539e0 100644 --- a/vllm/model_executor/layers/fused_moe/routing_simulator.py +++ b/vllm/model_executor/layers/fused_moe/routing_simulator.py @@ -10,7 +10,7 @@ """ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any import torch @@ -24,7 +24,7 @@ def route_tokens( hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int, - indices_type: Optional[torch.dtype] = None, + indices_type: torch.dtype | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Route tokens to experts. @@ -89,7 +89,7 @@ def route_tokens( hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int, - indices_type: Optional[torch.dtype] = None, + indices_type: torch.dtype | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Randomly select experts for each token using the specified distribution. 
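The following is an illustrative sketch (not part of the diff) of the buffer layout that init_aiter_topK_meta_data in rocm_aiter_fused_moe.py above preallocates when shared experts are fused into the routed top-k. Tiny sizes, CPU tensors, and the `owned` helper variable are assumptions made purely for illustration; the real helper allocates int32/float32 buffers on "cuda", and the routed columns are filled later by the AITER grouped top-k kernels inside rocm_aiter_grouped_topk.

import torch

n_routed_experts, n_shared_experts, top_k = 4, 1, 2
max_num_tokens, tp_rank, tp_size, is_EP = 6, 0, 2, True
fake_expertid = n_routed_experts + n_shared_experts  # sentinel expert id

# One extra column per shared expert, plus one sentinel column under EP.
total_topk_ids = torch.empty(
    (max_num_tokens, top_k + n_shared_experts + is_EP), dtype=torch.int32
)
ns_topk_ids, s_topk_ids = total_topk_ids.split(
    [top_k, n_shared_experts + is_EP], dim=1
)

# Routed columns are written by the top-k kernel at runtime; -1 is used here
# only to make the printout readable.
ns_topk_ids.fill_(-1)

# Shared-expert columns: rows owned by this EP rank get the real shared-expert
# ids, all other rows get the sentinel so shared experts are not applied twice.
s_topk_ids[:] = fake_expertid
owned = torch.arange(tp_rank, max_num_tokens, tp_size)
s_topk_ids[owned] = torch.tensor(
    [n_routed_experts + i for i in range(n_shared_experts + is_EP)],
    dtype=torch.int32,
)

# Owned rows read [-1, -1, 4, 5]; non-owned rows read [-1, -1, 5, 5].
# total_topk_weights follows the same split, with the shared columns filled
# with shared_experts_score.
print(total_topk_ids)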
@@ -269,7 +269,7 @@ def simulate_routing( router_logits: torch.Tensor, strategy_name: str, top_k: int, - indices_type: Optional[torch.dtype] = None, + indices_type: torch.dtype | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Simulate token-to-expert routing using the specified strategy. diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py new file mode 100644 index 000000000000..2db733b765ce --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + +# TODO(bnell): Add shared + fused combo function? e.g. + +class SharedFusedMoE(FusedMoE): + """ + A FusedMoE operation that also computes the results of shared experts. + If an all2all communicator is being used the shared expert computation + can be interleaved with the fused all2all dispatch communication step. + """ + + def __init__( + self, + shared_experts: torch.nn.Module | None, + gate: torch.nn.Module | None = None, + use_overlapped: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._shared_experts = shared_experts + + # Disable shared expert overlap if EP is disabled or we are not using + # flashinfer + DP since there is nothing to be gained in this case. + # Disabling the overlap optimization also prevents the shared experts + # from being hidden from torch.compile. + self.use_overlapped = ( + use_overlapped + and not ( + self.use_ep + or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1) + ) + and self._shared_experts is not None + ) + + self._gate = gate + + @property + def shared_experts(self) -> torch.nn.Module | None: + return self._shared_experts if self.use_overlapped else None + + @property + def gate(self) -> torch.nn.Module | None: + return self._gate if self.use_overlapped else None + + @property + def is_internal_router(self) -> bool: + return self.gate is not None + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not self.use_overlapped: + if self._shared_experts is not None: + shared_out = self._shared_experts(hidden_states) + + # Reduce shared expert outputs if necessary, since the MLP + # should have been created with reduce_results=False. 
+ if ( + self.reduce_results + and self.tp_size > 1 + and self.must_reduce_shared_expert_outputs() + ): + shared_out = tensor_model_parallel_all_reduce(shared_out) + else: + shared_out = None + + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + else: + shared_out, fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py index e725a0f00363..99d4038ec381 100644 --- a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py +++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -29,7 +28,7 @@ def __eq__(self, other): def apply( self, - output: Optional[torch.Tensor], + output: torch.Tensor | None, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, @@ -52,7 +51,7 @@ def __eq__(self, other): def apply( self, - output: Optional[torch.Tensor], + output: torch.Tensor | None, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, @@ -84,7 +83,7 @@ def __eq__(self, other): def apply( self, - output: Optional[torch.Tensor], + output: torch.Tensor | None, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, @@ -133,7 +132,7 @@ def __eq__(self, other): def apply( self, - output: Optional[torch.Tensor], + output: torch.Tensor | None, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 9c35d7d2fe12..b8e0837162ef 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -11,9 +10,11 @@ _valid_deep_gemm, _valid_deep_gemm_shape, ) -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import ( + get_mk_alignment_for_contiguous_layout, + is_deep_gemm_e8m0_used, +) class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -29,7 +30,7 @@ def __init__( self.allow_deep_gemm = ( allow_deep_gemm and self.quant_config.use_fp8_w8a8 - and self.block_shape == deep_gemm_block_shape() + and self.block_shape == get_mk_alignment_for_contiguous_layout() ) self.deep_gemm_expert = ( @@ -83,16 +84,14 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be 
pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. @@ -101,8 +100,6 @@ def workspace_shapes( ): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( - a, - aq, M, N, K, @@ -113,8 +110,6 @@ def workspace_shapes( ) else: return self.triton_expert.workspace_shapes( - a, - aq, M, N, K, @@ -134,12 +129,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): use_deep_gemm = self.allow_deep_gemm and ( diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 8eb724a7435f..e305483eb17d 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -12,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) -from vllm.utils import next_power_of_2 class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -52,47 +50,19 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # The workspaces for this implementation are managed by flashinfer. - # TODO(varun) : workspace1 is could be used as the output tensor. This - # is error-prone. Allow the `workspace_shapes` to return None workspaces - workspace1 = (M, K) - workspace2 = (0, 0) + workspace1 = (0,) + workspace2 = (0,) output = (M, K) - return (workspace1, workspace2, output, a.dtype) - - def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int): - # Number of tokens in the input tensor. - num_tokens = x.shape[0] - # Factor to account for the imbalance of the experts. - # factor equals to the - # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert - # 1.0 means perfect expert distribution. - # > 1.0 means some experts have more tokens than the perfect - # distribution. - # < 1.0 does not make sense. - imbalance_factor = 1.3 - # Calculate the number of tokens per expert assuming perfect - # distribution. - num_tokens_per_expert = (num_tokens * top_k) // local_num_experts - # Apply the imbalance factor. - num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) - # And pad the number to the next power of 2. - tile_tokens_dim = next_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile as it's the range supported by the - # kernel. 
- tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - - return tile_tokens_dim + return (workspace1, workspace2, output) def apply( self, @@ -104,12 +74,12 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): topk = topk_ids.size(-1) @@ -153,9 +123,7 @@ def apply( "local_expert_offset": local_expert_offset, "local_num_experts": local_num_experts, "routed_scaling_factor": None, - "tile_tokens_dim": self._get_tile_tokens_dim( - x_quant, topk, local_num_experts - ), + "tile_tokens_dim": None, "routing_method_type": 1, "do_finalize": True, "output": output, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index dddf788b62e2..0627ea50d821 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools from math import prod -from typing import Optional, Union import torch @@ -25,6 +25,7 @@ from vllm.triton_utils import tl, triton from vllm.utils import cdiv from vllm.utils.flashinfer import flashinfer_fp4_quantize +from vllm.utils.torch_utils import is_torch_equal_or_newer @triton.jit @@ -60,7 +61,7 @@ def _count_expert_num_tokens( def count_expert_num_tokens( - topk_ids: torch.Tensor, num_local_experts: int, expert_map: Optional[torch.Tensor] + topk_ids: torch.Tensor, num_local_experts: int, expert_map: torch.Tensor | None ) -> torch.Tensor: """ Count the number to tokens assigned to each expert. @@ -112,7 +113,7 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: def _nvfp4_quantize( A: torch.Tensor, - A_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, is_sf_swizzled_layout: bool, ) -> tuple[torch.Tensor, torch.Tensor]: return flashinfer_fp4_quantize( @@ -122,9 +123,9 @@ def _nvfp4_quantize( def _fp8_quantize( A: torch.Tensor, - A_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, per_act_token: bool, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Perform fp8 quantization on the inputs. If a block_shape @@ -148,9 +149,9 @@ def _fp8_quantize( def _int8_quantize( A: torch.Tensor, - A_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, per_act_token: bool, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Perform int8 quantization on the inputs. 
If a block_shape @@ -175,9 +176,9 @@ def _int8_quantize( def _mxfp4_quantize( A: torch.Tensor, - A_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> tuple[torch.Tensor, None]: assert block_shape is None # TODO: native mxfp4 is currently not integrated in vllm, @@ -191,9 +192,9 @@ def _mxfp4_quantize( def _mxfp8_e4m3_quantize( A: torch.Tensor, - A_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: assert A_scale is None assert not per_act_token_quant @@ -203,9 +204,9 @@ def _mxfp8_e4m3_quantize( def _mxfp6_e3m2_quantize( A: torch.Tensor, - A_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> tuple[torch.Tensor, None]: assert block_shape is None @@ -220,9 +221,9 @@ def _mxfp6_e3m2_quantize( def _mxfp6_e2m3_quantize( A: torch.Tensor, - A_scale: Optional[torch.Tensor], + A_scale: torch.Tensor | None, per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, ) -> tuple[torch.Tensor, None]: assert block_shape is None @@ -237,12 +238,12 @@ def _mxfp6_e2m3_quantize( def moe_kernel_quantize_input( A: torch.Tensor, - A_scale: Optional[torch.Tensor], - quant_dtype: Union[None, torch.dtype, str], + A_scale: torch.Tensor | None, + quant_dtype: None | torch.dtype | str, per_act_token_quant: bool, - block_shape: Optional[list[int]] = None, + block_shape: list[int] | None = None, is_fp4_scale_swizzled: bool = True, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor | None]: if quant_dtype == torch.float8_e4m3fn: return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: @@ -273,7 +274,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: return m[idx, ...] -def normalize_scales_shape(scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: +def normalize_scales_shape(scales: torch.Tensor | None) -> torch.Tensor | None: if scales is not None: if scales.numel() == 1: scales = scales.view(1, 1) @@ -283,9 +284,9 @@ def normalize_scales_shape(scales: Optional[torch.Tensor]) -> Optional[torch.Ten def normalize_batched_scales_shape( - scales: Optional[torch.Tensor], + scales: torch.Tensor | None, num_experts: int, -) -> Optional[torch.Tensor]: +) -> torch.Tensor | None: if scales is not None and scales.ndim < 3: if scales.numel() == 1: scales = scales.view(1) @@ -300,9 +301,9 @@ def normalize_batched_scales_shape( def _validate_scale_shape( a: torch.Tensor, - a_scale: Optional[torch.Tensor], + a_scale: torch.Tensor | None, per_act_token_quant: bool, - block_shape: Optional[list[int]], + block_shape: list[int] | None, ) -> None: if a_scale is None: return @@ -321,3 +322,11 @@ def _validate_scale_shape( def activation_without_mul(activation: str) -> str: return activation + "_no_mul" + + +# Torch custom ops can't deal with outputs aliasing inputs so we need to +# disable inplace for torch >= 2.9. 
+# See https://github.com/vllm-project/vllm/issues/26378 +@functools.cache +def disable_inplace() -> bool: + return is_torch_equal_or_newer("2.9") diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 6a49ae42ca89..65432c0fb2d4 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -2,16 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom normalization layers.""" -from typing import Optional, Union - import torch import torch.nn as nn import torch.nn.functional as F import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.batch_invariant import ( + rms_norm_batch_invariant, + vllm_is_batch_invariant, +) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: @@ -23,6 +25,8 @@ def rms_norm( ) -> torch.Tensor: from vllm import _custom_ops as ops + if vllm_is_batch_invariant(): + return rms_norm_batch_invariant(x, weight, variance_epsilon) out = torch.empty_like(x) ops.rms_norm( out, @@ -41,6 +45,10 @@ def fused_add_rms_norm( ) -> tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops + if vllm_is_batch_invariant(): + return rms_norm_batch_invariant( + x + residual, weight, variance_epsilon + ), x + residual ops.fused_add_rms_norm( x, residual, @@ -50,22 +58,6 @@ def fused_add_rms_norm( return x, residual -def poly_norm( - x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float -) -> torch.Tensor: - from vllm import _custom_ops as ops - - out = torch.empty_like(x) - ops.poly_norm( - out, - x, - weight, - bias, - variance_epsilon, - ) - return out - - def rocm_aiter_rms_norm_impl( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: @@ -159,9 +151,9 @@ def __init__( self, hidden_size: int, eps: float = 1e-6, - var_hidden_size: Optional[int] = None, + var_hidden_size: int | None = None, has_weight: bool = True, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, ) -> None: super().__init__() @@ -170,14 +162,11 @@ def __init__( self.variance_size_override = ( None if var_hidden_size == hidden_size else var_hidden_size ) + weight_dtype = dtype or torch.get_default_dtype() self.has_weight = has_weight - if dtype is not None: - self.weight = torch.ones(hidden_size, dtype=dtype) - else: - self.weight = torch.ones(hidden_size) + self.weight = torch.ones(hidden_size, dtype=weight_dtype) if self.has_weight: self.weight = nn.Parameter(self.weight) - weight_dtype = self.weight.data.dtype if current_platform.is_rocm(): self.rocm_norm_func = dispatch_rocm_rmsnorm_func( @@ -187,52 +176,74 @@ def __init__( with_fused_add=True, dtype=weight_dtype ) - def forward_native( - self, + @staticmethod + def forward_static( x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + variance_epsilon: float, + hidden_size: int, + orig_dtype: torch.dtype, + weight: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + variance_size_override: int | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype x = x.to(torch.float32) if residual is not None: - x = x + residual.to(torch.float32) + # residual promoted f16->f32 
automatically, + # otherwise Inductor eliminates the casts to and from f16, + # increasing memory usage (and complicating pattern matching) + x = x + residual residual = x.to(orig_dtype) - hidden_size = x.shape[-1] - if hidden_size != self.hidden_size: + if x.shape[-1] != hidden_size: raise ValueError( - "Expected hidden_size to be " - f"{self.hidden_size}, but found: {hidden_size}" + f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}" ) - if self.variance_size_override is None: + if variance_size_override is None: x_var = x else: - if hidden_size < self.variance_size_override: + if hidden_size < variance_size_override: raise ValueError( "Expected hidden_size to be at least " - f"{self.variance_size_override}, but found: {hidden_size}" + f"{variance_size_override}, but found: {hidden_size}" ) - x_var = x[:, :, : self.variance_size_override] + x_var = x[:, :, :variance_size_override] variance = x_var.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x * torch.rsqrt(variance + variance_epsilon) x = x.to(orig_dtype) - if self.has_weight: - x = x * self.weight + if weight is not None: + x = x * weight if residual is None: return x else: return x, residual + def forward_native( + self, + x: torch.Tensor, + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + + return self.forward_static( + x, + self.variance_epsilon, + self.hidden_size, + x.dtype, + self.weight.data if self.has_weight else None, + residual, + self.variance_size_override, + ) + def forward_cuda( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if self.variance_size_override is not None: return self.forward_native(x, residual) @@ -247,8 +258,8 @@ def forward_cuda( def forward_hip( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if self.variance_size_override is not None: return self.forward_native(x, residual) @@ -263,8 +274,8 @@ def forward_hip( def forward_xpu( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if self.variance_size_override is not None: return self.forward_native(x, residual) @@ -313,12 +324,16 @@ def forward_static( weight: torch.Tensor, variance_epsilon: float, x: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype if residual is not None: - x = x + residual.float() if orig_dtype == torch.float16 else x + residual + x = ( + x.float() + residual.float() + if orig_dtype == torch.float16 + else x + residual + ) residual = x x = x.float() @@ -333,16 +348,16 @@ def forward_static( def forward_native( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | 
tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" return self.forward_static(self.weight.data, self.variance_epsilon, x, residual) def forward_cuda( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if torch.compiler.is_compiling(): return self.forward_native(x, residual) @@ -354,53 +369,6 @@ def forward_cuda( return self.forward_native(x, residual) -@CustomOp.register("poly_norm") -class PolyNorm(CustomOp): - """Polynomial normalization. - - Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b - where w_n is the learned weight and b is the bias. - Refer to https://arxiv.org/html/2411.03884v1 - """ - - def __init__( - self, - eps: float = 1e-6, - ) -> None: - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(3) / 3) - self.bias = torch.nn.Parameter(torch.zeros(1)) - self.variance_epsilon = eps - - def _norm(self, x): - return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) - - def forward_native( - self, - x: torch.Tensor, - ) -> torch.Tensor: - """PyTorch-native implementation equivalent to forward(). - - Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md - """ - - orig_dtype = x.dtype - x_float = x.to(torch.float32) - output = ( - self.weight[0] * self._norm(x_float**3) - + self.weight[1] * self._norm(x_float**2) - + self.weight[2] * self._norm(x_float) - + self.bias - ) - return output.to(orig_dtype) - - def forward_cuda( - self, - x: torch.Tensor, - ) -> torch.Tensor: - return poly_norm(x, self.weight, self.bias, self.variance_epsilon) - - class LayerNorm(nn.Module): """ Layer Normalization. diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index e874301b02c0..99853680eac6 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch from einops import rearrange @@ -529,7 +528,7 @@ def lightning_attention( v: torch.Tensor, ed: torch.Tensor, block_size: int = 256, - kv_history: Optional[torch.Tensor] = None, + kv_history: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply lightning attention algorithm diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3881ba12faa0..dfcc601a1c53 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -3,10 +3,9 @@ import itertools from abc import abstractmethod -from typing import Any, Literal, Optional, Union +from typing import Any import torch -import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import ( @@ -35,7 +34,6 @@ ) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import GiB_bytes logger = init_logger(__name__) @@ -188,7 +186,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: """Apply the weights in layer to the input tensor. 
Expects create_weights to have been called before on the layer.""" @@ -212,33 +210,17 @@ def create_weights( # The weights are not quantized, and they are not sharded. # The amount of memory allocated for the weights is # sum(output_partition_sizes) * input_size_per_partition. - try: - weight_loader = extra_weight_attrs.pop("weight_loader") - weight = ModelWeightParameter( - data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - except torch.cuda.OutOfMemoryError as e: - logger.error("Failed to create unquantized linear weights: %s", e) - if torch.cuda.is_available(): - logger.debug("CUDA device: %s", torch.cuda.current_device()) - logger.debug( - "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes - ) - logger.debug( - "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes - ) - raise RuntimeError( - "Failed to create unquantized linear weights. " - "This may be caused by insufficient memory to allocate " - "the weight." - ) from e + weight_loader = extra_weight_attrs.pop("weight_loader") + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) @@ -253,7 +235,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) @@ -277,8 +259,8 @@ def __init__( input_size: int, output_size: int, skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", *, return_bias: bool = True, @@ -296,7 +278,7 @@ def __init__( self.quant_config = quant_config self.prefix = prefix if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() + self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod() else: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias @@ -334,8 +316,8 @@ def __init__( output_size: int, bias: bool = True, skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", *, return_bias: bool = True, @@ -410,7 +392,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def forward( self, x: torch.Tensor, - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None @@ -462,9 +444,9 @@ def __init__( bias: bool = True, gather_output: bool = False, skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[list[int]] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + output_sizes: list[int] | None = None, prefix: str = "", *, return_bias: bool = True, @@ -575,7 +557,7 @@ def weight_loader_v2(self, 
param: BasevLLMParameter, loaded_weight: torch.Tensor def forward( self, input_, - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None # Matrix multiply. @@ -634,8 +616,8 @@ def __init__( bias: bool = True, gather_output: bool = False, skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", *, return_bias: bool = True, @@ -663,7 +645,7 @@ def weight_loader( self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None, + loaded_shard_id: int | None = None, ): # Special case for GGUF # initialize GGUF param after we know the quantize type @@ -839,7 +821,7 @@ def weight_loader_v2( self, param: BasevLLMParameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None, + loaded_shard_id: int | None = None, ): if loaded_shard_id is None: if isinstance(param, PerTensorScaleParameter): @@ -915,11 +897,11 @@ def __init__( hidden_size: int, head_size: int, total_num_heads: int, - total_num_kv_heads: Optional[int] = None, + total_num_kv_heads: int | None = None, bias: bool = True, skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", *, return_bias: bool = True, @@ -1028,7 +1010,7 @@ def weight_loader_v2( self, param: BasevLLMParameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, + loaded_shard_id: str | None = None, ): if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): @@ -1072,7 +1054,7 @@ def weight_loader( self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, + loaded_shard_id: str | None = None, ): # Special case for GGUF # initialize GGUF param after we know the quantize type @@ -1230,10 +1212,10 @@ def weight_loader( param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": - shard_id = self.tp_rank + shard_rank = self.tp_rank else: - shard_id = self.tp_rank // self.num_kv_head_replicas - start_idx = shard_id * shard_size + shard_rank = self.tp_rank // self.num_kv_head_replicas + start_idx = shard_rank * shard_size if not is_sharded_weight: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) @@ -1297,9 +1279,9 @@ def __init__( bias: bool = True, input_is_parallel: bool = True, skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, + params_dtype: torch.dtype | None = None, reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", *, return_bias: bool = True, @@ -1406,7 +1388,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor def forward( self, input_, - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: if self.input_is_parallel: input_parallel = input_ else: @@ -1440,237 +1422,3 @@ def extra_repr(self) -> str: s += f", tp_size={self.tp_size}" s += f", reduce_results={self.reduce_results}" return s - - -@CustomOp.register("qkv_cross_parallel_linear") -class 
QKVCrossParallelLinear(LinearBase): - """Linear layers for efficient cross-attention's QKV transformation. - - Args: - hidden_size: input hidden state size of the transformer. - head_size: size of each attention head. - total_num_heads: total number of attention query heads. - total_num_kv_heads: total number of attention key/value heads. If - None, assume total_num_kv_heads = total_num_heads. - bias: If true, add bias. - skip_bias_add: This was added to enable performance optimizations where - bias can be fused with other element-wise operations. we - skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) - """ - - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - # input_size and output_size are not used, just for alignment - input_size = hidden_size - output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size - super().__init__( - input_size=input_size, - output_size=output_size, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix, - ) - - self.quant_config = quant_config - - # Empty placeholders for loading as a single module. - placeholder_size = 0 - assert self.quant_method is not None - self.quant_method.create_weights( - self, - placeholder_size, - [placeholder_size], - placeholder_size, - placeholder_size, - self.params_dtype, - weight_loader=self.weight_loader, - ) - - # Use a dictionary to avoid submodules parameters auto-registration: - # drop-in replacement for a `QKVParallelLinear` module. - self.proj = dict() - self.proj["q_proj_decoder"] = ColumnParallelLinear( - input_size=hidden_size, - output_size=total_num_heads * head_size, - bias=bias, - quant_config=quant_config, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - prefix=f"{prefix}.q_proj_decoder", - ) - - self.proj["kv_proj_encoder"] = QKVParallelLinear( - hidden_size=hidden_size, - head_size=head_size, - total_num_heads=0, - total_num_kv_heads=total_num_kv_heads, - bias=bias, - quant_config=quant_config, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - prefix=f"{prefix}.kv_proj_encoder", - ) - - # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. 
- self.q_size = self.q_proj_decoder.output_size_per_partition - self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size - - if bias: - self.bias = torch.nn.Parameter() - set_weight_attrs( - self.bias, - { - "output_dim": 0, - "weight_loader": self.weight_loader_v1, - }, - ) - else: - self.bias = None - - def process_weights_after_loading(self): - for layer in self.proj.values(): - if self.quant_method is not None: - self.quant_method.process_weights_after_loading(layer) - - @property - def q_proj_decoder(self) -> ColumnParallelLinear: - layer = self.proj["q_proj_decoder"] - for name, param in self.named_parameters(): - target_param = getattr(layer, name, None) - if target_param is not None: - self.sync_weight_attrs(param, target_param, mode="q_proj_decoder") - return layer - - @property - def kv_proj_encoder(self) -> QKVParallelLinear: - layer = self.proj["kv_proj_encoder"] - for name, param in self.named_parameters(): - target_param = getattr(layer, name, None) - if target_param is not None: - self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder") - return layer - - def sync_weight_attrs( - self, - src_param: nn.Parameter, - tgt_param: nn.Parameter, - mode: Literal["q_proj_decoder", "kv_proj_encoder"], - ): - missing_attrs_dict = { - k: getattr(src_param, k) - for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys())) - } - # TODO(Isotr0py): handle bitsandbytes 8bit - use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False) - if missing_attrs_dict and use_bitsandbytes_4bit: - q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard( - missing_attrs_dict - ) - if mode == "q_proj_decoder": - set_weight_attrs(tgt_param, q_proj_attrs) - elif mode == "kv_proj_encoder": - set_weight_attrs(tgt_param, kv_proj_attrs) - else: - set_weight_attrs(tgt_param, missing_attrs_dict) - - def _is_same_param( - self, - src_param: torch.nn.Parameter, - map_param: torch.nn.Parameter, - ) -> bool: - """Check if two parameters are exactly pointing to same things.""" - # ignore weight_loader because it's always different - key_to_ignore = ["weight_loader", "_weight_loader"] - has_same_type_name = type(src_param) is type(map_param) - src_param_attrs = { - k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore - } - map_param_attrs = { - k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore - } - has_same_attrs = src_param_attrs == map_param_attrs - return has_same_type_name and has_same_attrs - - def select_proj_params( - self, - layer: nn.Module, - param: nn.Parameter, - ) -> nn.Parameter: - """ - Given the placeholder param, - return the corresponding param in the proj layers. - """ - target_param_list = [ - v for _, v in layer.named_parameters() if self._is_same_param(param, v) - ] - assert len(target_param_list) == 1 - target_param = target_param_list[0] - return target_param - - def forward( # type: ignore[override] - self, - decoder_hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - q, _ = self.q_proj_decoder(decoder_hidden_states) - if encoder_hidden_states is None: - # Encoder KV already cached. - k = None - v = None - else: - # Prefill phase, encoder KV cached here. 
- kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states) - # Split kv in half - k, v = kv_enc.split(self.kv_size, dim=-1) - return q, k, v - - def weight_loader_v1( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, - ): - # just like all other parameters, does not yet - # support loading bias with weight_loader_v2 - layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder - target_param = self.select_proj_params(layer, param) - shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else () - layer.weight_loader(target_param, loaded_weight, *shard_id_args) - - def weight_loader( - self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, - ): - layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder - target_param = self.select_proj_params(layer, param) - shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else () - if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED: - layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args) - else: - layer.weight_loader(target_param, loaded_weight, *shard_id_args) - - def extra_repr(self) -> str: - s = f"in_features={self.input_size}" - s += f", q_size={self.q_size}" - s += f", kv_size={self.kv_size}" - s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" - s += ", gather_output=False" - return s diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 3db5e0b32553..c8d57f597d1c 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that compute logits from hidden_stats.""" -from typing import Optional - import torch from vllm.distributed import ( @@ -28,10 +26,10 @@ class LogitsProcessor(CustomOp): def __init__( self, vocab_size: int, - org_vocab_size: Optional[int] = None, + org_vocab_size: int | None = None, scale: float = 1.0, logits_as_input: bool = False, - soft_cap: Optional[float] = None, + soft_cap: float | None = None, ) -> None: """ Args: @@ -53,8 +51,8 @@ def forward( self, lm_head: VocabParallelEmbedding, hidden_states: torch.Tensor, - embedding_bias: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: + embedding_bias: torch.Tensor | None = None, + ) -> torch.Tensor | None: if self.logits_as_input: logits = hidden_states else: @@ -88,8 +86,8 @@ def _get_logits( self, hidden_states: torch.Tensor, lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: + embedding_bias: torch.Tensor | None, + ) -> torch.Tensor | None: # Get the logits for the next tokens. 
logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index 6da62b5426bb..e68b09b4d81f 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -6,7 +6,9 @@ import torch +from vllm.config import VllmConfig from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -40,3 +42,30 @@ def mamba_type(self) -> str: def get_attn_backend(self) -> type["AttentionBackend"]: """Get the attention backend class for this Mamba layer.""" pass + + @abstractmethod + def get_state_dtype(self) -> tuple[torch.dtype, ...]: + pass + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: + if ( + vllm_config.speculative_config is not None + and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"] + ): + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet." + ) + mamba_block_size = vllm_config.cache_config.mamba_block_size + page_size_padded = vllm_config.cache_config.mamba_page_size_padded + return MambaSpec( + shapes=self.get_state_shape(), + dtypes=self.get_state_dtype(), + block_size=mamba_block_size, + page_size_padded=page_size_padded, + mamba_type=self.mamba_type, + num_speculative_blocks=( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ), + ) diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 99f05e2eca0e..fd4567ee4701 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING import torch -import torch.distributed import torch.nn.functional as F from einops import rearrange from torch import nn @@ -35,15 +34,12 @@ MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend -import torch -import torch.distributed - class MiniMaxText01RMSNormTP(CustomOp): name = "MiniMaxText01RMSNormTP" @@ -87,8 +83,8 @@ def _forward( def forward( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert residual is None, "RMSNorm does not support residual connection." 
return self._forward(x) @@ -102,7 +98,7 @@ def jit_linear_forward_prefix( kv_caches: torch.Tensor, slope_rate: torch.Tensor, block_size: int, - layer_idx: Optional[int] = None, + layer_idx: int | None = None, **kwargs, ) -> torch.Tensor: slope_rate = slope_rate.to(torch.float32) @@ -154,9 +150,9 @@ def __init__( max_position: int, block_size: int, num_hidden_layer: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, layer_idx: int = 0, linear_layer_idx: int = 0, prefix: str = "linear_attn", diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 8ab77965ae80..a9a0c216474b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -37,7 +37,7 @@ selective_state_update, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata @@ -68,8 +68,8 @@ def __init__( rms_norm_eps: float = 1e-5, activation="silu", is_lora_enabled: bool = False, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, prefix: str = "", ): super().__init__() @@ -410,7 +410,7 @@ def get_attn_backend(self) -> type["AttentionBackend"]: return Mamba1AttentionBackend - def _time_proj_bias(self) -> Optional[torch.Tensor]: + def _time_proj_bias(self) -> torch.Tensor | None: if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: return self.dt_proj.bias.float() return None @@ -423,8 +423,8 @@ class PrefillDecodeSplit(NamedTuple): gate_d: torch.Tensor state_indices_tensor_p: torch.Tensor state_indices_tensor_d: torch.Tensor - query_start_loc_p: Optional[torch.Tensor] - has_initial_states_p: Optional[torch.Tensor] + query_start_loc_p: torch.Tensor | None + has_initial_states_p: torch.Tensor | None def split_batch_to_prefill_and_decode( @@ -432,7 +432,7 @@ def split_batch_to_prefill_and_decode( gate: torch.Tensor, state_indices_tensor: torch.Tensor, query_start_loc: torch.Tensor, - has_initial_states: Optional[torch.Tensor], + has_initial_states: torch.Tensor | None, num_prefill_tokens: int, num_decode_tokens: int, num_prefills: int, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 7589905ac927..fb45afa33dad 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -46,7 +46,7 @@ sharded_weight_loader, ) from vllm.model_executor.utils import set_weight_attrs 
-from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 @@ -138,7 +138,7 @@ def forward_cuda( self, x: torch.Tensor, gate: torch.Tensor, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: input_dtype = x.dtype if not self.use_rms_norm: # Keep gate in float32 for numerical stability during silu @@ -244,9 +244,9 @@ def __init__( rms_norm_eps: float = 1e-5, activation: str = "silu", use_rms_norm: bool = True, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -474,7 +474,7 @@ def forward_native( self, hidden_states: torch.Tensor, output: torch.Tensor, - mup_vector: Optional[torch.Tensor] = None, + mup_vector: torch.Tensor | None = None, ): pass @@ -482,7 +482,7 @@ def forward( self, hidden_states: torch.Tensor, output: torch.Tensor, - mup_vector: Optional[torch.Tensor] = None, + mup_vector: torch.Tensor | None = None, ): torch.ops.vllm.mamba_mixer2( hidden_states, @@ -495,7 +495,7 @@ def forward_cuda( self, hidden_states: torch.Tensor, output: torch.Tensor, - mup_vector: Optional[torch.Tensor] = None, + mup_vector: torch.Tensor | None = None, ): forward_context = get_forward_context() # attn_metadata contains metadata necessary for the mamba2 triton @@ -904,7 +904,7 @@ def mamba_mixer2( hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str, - mup_vector: Optional[torch.Tensor] = None, + mup_vector: torch.Tensor | None = None, ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] @@ -915,7 +915,7 @@ def mamba_mixer2_fake( hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str, - mup_vector: Optional[torch.Tensor] = None, + mup_vector: torch.Tensor | None = None, ) -> None: return diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 21c36617a872..91a45623582d 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -1,19 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Union import torch -from vllm.config import MambaDType, ModelDType +from vllm.config.cache import MambaDType +from vllm.config.model import ModelDType from vllm.distributed import divide -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + get_kv_cache_torch_dtype, +) class MambaStateDtypeCalculator: @classmethod def linear_attention_state_dtype( cls, - model_dtype: Union[ModelDType, torch.dtype], + model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: # TODO (tdoublep) requires testing @@ -25,7 +28,7 @@ def linear_attention_state_dtype( @classmethod def mamba1_state_dtype( cls, - model_dtype: Union[ModelDType, torch.dtype], + model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: @@ -36,7 +39,7 @@ def mamba1_state_dtype( 
@classmethod def mamba2_state_dtype( cls, - model_dtype: Union[ModelDType, torch.dtype], + model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: @@ -47,7 +50,7 @@ def mamba2_state_dtype( @classmethod def _mamba_state_dtype( cls, - model_dtype: Union[ModelDType, torch.dtype], + model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: @@ -62,7 +65,7 @@ def _mamba_state_dtype( @classmethod def short_conv_state_dtype( cls, - model_dtype: Union[ModelDType, torch.dtype], + model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) @@ -71,7 +74,7 @@ def short_conv_state_dtype( @classmethod def gated_delta_net_state_dtype( cls, - model_dtype: Union[ModelDType, torch.dtype], + model_dtype: ModelDType | torch.dtype, mamba_cache_dtype: MambaDType, ) -> tuple[torch.dtype, torch.dtype]: state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index ec486d3b9267..83c2c5f11e18 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -4,7 +4,6 @@ # Copyright (c) 2024, Tri Dao. # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py -from typing import Optional, Union import numpy as np import torch @@ -469,17 +468,17 @@ def _causal_conv1d_fwd_kernel( # continuous batching def causal_conv1d_fn( x: torch.Tensor, weight: torch.Tensor, - bias: Union[torch.Tensor, None], + bias: torch.Tensor | None, conv_states: torch.Tensor, query_start_loc: torch.Tensor, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", + cache_indices: torch.Tensor | None = None, + has_initial_state: torch.Tensor | None = None, + activation: str | None = "silu", pad_slot_id: int = PAD_SLOT_ID, - block_idx_first_scheduled_token: Optional[torch.Tensor] = None, - block_idx_last_scheduled_token: Optional[torch.Tensor] = None, - initial_state_idx: Optional[torch.Tensor] = None, - num_computed_tokens: Optional[torch.Tensor] = None, + block_idx_first_scheduled_token: torch.Tensor | None = None, + block_idx_last_scheduled_token: torch.Tensor | None = None, + initial_state_idx: torch.Tensor | None = None, + num_computed_tokens: torch.Tensor | None = None, block_size_to_align=0, metadata=None, validate_data=False, @@ -1071,15 +1070,15 @@ def causal_conv1d_update( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Union[bool, str, None] = None, - conv_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, + activation: bool | str | None = None, + conv_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + query_start_loc: torch.Tensor | None = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, - block_idx_last_scheduled_token: Optional[torch.Tensor] = None, - initial_state_idx: Optional[torch.Tensor] = None, + block_idx_last_scheduled_token: torch.Tensor | None = None, + initial_state_idx: 
torch.Tensor | None = None, validate_data=False, ): """ diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index 32273d137eca..04efa8a8b373 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -27,7 +27,7 @@ causal_conv1d_fn, causal_conv1d_update, ) -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata @@ -38,8 +38,8 @@ def __init__( config, dim: int, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, prefix: str = "", ): super().__init__() diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index b8e99226d13e..34f05f2ee962 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import torch -from vllm.attention import Attention +from vllm.attention.layer import MLAAttention from vllm.config import CacheConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization import QuantizationConfig @@ -19,19 +18,20 @@ class MLAModules: kv_b_proj: torch.nn.Module rotary_emb: torch.nn.Module o_proj: torch.nn.Module - fused_qkv_a_proj: Optional[torch.nn.Module] - kv_a_proj_with_mqa: Optional[torch.nn.Module] - q_a_layernorm: Optional[torch.nn.Module] - q_b_proj: Optional[torch.nn.Module] - q_proj: Optional[torch.nn.Module] - indexer: Optional[torch.nn.Module] + fused_qkv_a_proj: torch.nn.Module | None + kv_a_proj_with_mqa: torch.nn.Module | None + q_a_layernorm: torch.nn.Module | None + q_b_proj: torch.nn.Module | None + q_proj: torch.nn.Module | None + indexer: torch.nn.Module | None is_sparse: bool - topk_indices_buffer: Optional[torch.Tensor] + topk_indices_buffer: torch.Tensor | None @CustomOp.register("multi_head_latent_attention") -class MultiHeadLatentAttention(CustomOp): - """MLA layer registered as CustomOp. +class MultiHeadLatentAttentionWrapper(CustomOp): + """MLA layer registered as CustomOp to allow OOT backends to add + custom implementations of the outer MLA layer (including rope & o_proj). Note that currently MLA ignores the enable/disable mechanism of CustomOp because there is only one in-tree implementation in forward_native. TODO: implement this with a new PluggableLayer mechanism. 
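Most of the hunks in this part of the patch are the same mechanical typing migration: `Optional[T]` becomes `T | None` and `Union[A, B]` becomes `A | B`. A minimal sketch of the pattern for reference, assuming a Python 3.10+ baseline so PEP 604 unions are valid in annotations evaluated at runtime; the function below is hypothetical and only illustrates the before/after shape, not any vLLM API.

import torch

# Before: typing-module generics.
#   from typing import Optional, Union
#   def apply(x: torch.Tensor, bias: Optional[torch.Tensor] = None,
#             ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ...

# After: builtin union syntax, as adopted throughout this patch.
def apply(
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    # Toy body so the sketch runs: return the bias alongside the output
    # when one is provided.
    if bias is None:
        return x
    return x + bias, bias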
@@ -54,11 +54,11 @@ def __init__( qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, - q_lora_rank: Optional[int], + q_lora_rank: int | None, kv_lora_rank: int, mla_modules: MLAModules, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -87,30 +87,19 @@ def __init__( self.topk_tokens = self.indexer.topk_tokens self.topk_indices_buffer = mla_modules.topk_indices_buffer - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. - # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( + self.mla_attn = MLAAttention( num_heads=self.num_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, scale=scale, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - use_sparse=mla_modules.is_sparse, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, qk_nope_head_dim=self.qk_nope_head_dim, qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", kv_b_proj=self.kv_b_proj, + use_sparse=self.is_sparse, indexer=self.indexer, ) @@ -171,6 +160,7 @@ def forward_native( k_pe, output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), ) + return self.o_proj(attn_out)[0] def forward_cuda(self, *args, **kwargs): diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 979939ebc468..145f18f23566 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from collections.abc import Mapping, Set +from collections.abc import Callable, Mapping, Set from dataclasses import dataclass from enum import IntEnum from itertools import groupby -from typing import Callable, Optional, TypeVar, Union +from typing import TypeVar import torch import torch.nn as nn @@ -17,15 +17,15 @@ from vllm.model_executor.models.adapters import _load_st_projector from vllm.pooling_params import PoolingParams from vllm.tasks import PoolingTask -from vllm.utils import resolve_obj_by_qualname +from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.outputs import PoolerOutput from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata logger = init_logger(__name__) PoolingFn = Callable[ - [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], - Union[torch.Tensor, list[torch.Tensor]], + [torch.Tensor | list[torch.Tensor], PoolingMetadata], + torch.Tensor | list[torch.Tensor], ] ClassifierFn = Callable[[torch.Tensor], torch.Tensor] @@ -64,68 +64,8 @@ def apply(self, params: PoolingParams) -> None: params.requires_token_ids = self.requires_token_ids -class Pooler(nn.Module, ABC): - """The interface required for all poolers used in pooling models in vLLM.""" - - @staticmethod - def for_encode(pooler_config: PoolerConfig): - if pooler_config.pooling_type == "STEP": - return StepPooler() - - resolved_config = 
ResolvedPoolingConfig( - task="encode", pooling_type=PoolingType.ALL - ) - - return SimplePooler.from_config(resolved_config) - - @staticmethod - def for_embed(pooler_config: PoolerConfig): - resolved_config = ResolvedPoolingConfig.from_config( - task="embed", - pooler_config=pooler_config, - ) - - return SimplePooler.from_config(resolved_config) - - @staticmethod - def for_classify( - pooler_config: PoolerConfig, - classifier: Optional[ClassifierFn], - ): - resolved_config = ResolvedPoolingConfig.from_config( - task="classify", - pooler_config=pooler_config, - ) - - pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) - - return ClassifierPooler( - pooling=pooling, - classifier=classifier, - ) - - @abstractmethod - def get_supported_tasks(self) -> Set[PoolingTask]: - """Determine which pooling tasks are supported.""" - raise NotImplementedError - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - """ - Construct the updated pooling parameters to use for a supported task. - """ - return PoolingParamsUpdate() - - @abstractmethod - def forward( - self, - hidden_states: Union[list[torch.Tensor], torch.Tensor], - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - raise NotImplementedError - - def get_prompt_lens( - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, ) -> torch.Tensor: return pooling_metadata.prompt_lens @@ -174,7 +114,7 @@ def get_classification_activation_function(config: PretrainedConfig): def get_cross_encoder_activation_function(config: PretrainedConfig): - function_name: Optional[str] = None + function_name: str | None = None if ( hasattr(config, "sentence_transformers") and "activation_fn" in config.sentence_transformers @@ -223,27 +163,27 @@ def forward_all( self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> Union[list[torch.Tensor], torch.Tensor]: + ) -> list[torch.Tensor] | torch.Tensor: raise NotImplementedError def forward( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: + ) -> list[torch.Tensor] | torch.Tensor: pooling_cursor = pooling_metadata.pooling_cursor return self.forward_all(hidden_states, pooling_cursor) class CLSPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode", "embed", "classify", "score"} + return {"token_embed", "token_classify", "embed", "classify", "score"} def forward_all( self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> Union[list[torch.Tensor], torch.Tensor]: + ) -> list[torch.Tensor] | torch.Tensor: assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with CLS pooling" ) @@ -253,25 +193,25 @@ def forward_all( class LastPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode", "embed", "classify", "score"} + return {"token_embed", "token_classify", "embed", "classify", "score"} def forward_all( self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> Union[list[torch.Tensor], torch.Tensor]: + ) -> list[torch.Tensor] | torch.Tensor: return hidden_states[pooling_cursor.last_token_indices_gpu] class AllPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode"} + return {"token_embed", "token_classify"} def forward_all( self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> Union[list[torch.Tensor], torch.Tensor]: + ) -> 
list[torch.Tensor] | torch.Tensor: assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with ALL pooling" ) @@ -284,13 +224,13 @@ def forward_all( class MeanPool(PoolingMethod): def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode", "embed", "classify", "score"} + return {"token_embed", "token_classify", "embed", "classify", "score"} def forward_all( self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> Union[list[torch.Tensor], torch.Tensor]: + ) -> list[torch.Tensor] | torch.Tensor: assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with MEAN pooling" ) @@ -398,6 +338,94 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: return self.fn(pooled_data) +class Pooler(nn.Module, ABC): + """The interface required for all poolers used in pooling models in vLLM.""" + + @staticmethod + def for_token_embed(pooler_config: PoolerConfig): + head = TokenEmbeddingPoolerHead() + + if pooler_config.pooling_type == "STEP": + return StepPooler(head=head) + + return AllPooler(head=head) + + @staticmethod + def for_token_classify( + pooler_config: PoolerConfig, + classifier: ClassifierFn | None = None, + act_fn: PoolerActivation | str | None = None, + ): + head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn) + + if pooler_config.pooling_type == "STEP": + return StepPooler(head=head) + + return AllPooler(head=head) + + @staticmethod + def for_embed(pooler_config: PoolerConfig): + resolved_config = ResolvedPoolingConfig.from_config( + task="embed", + pooler_config=pooler_config, + ) + + pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) + head = EmbeddingPoolerHead() + + return SimplePooler(pooling=pooling, head=head) + + @staticmethod + def for_classify( + pooler_config: PoolerConfig, + classifier: ClassifierFn | None, + act_fn: PoolerActivation | str | None = None, + ): + resolved_config = ResolvedPoolingConfig.from_config( + task="classify", + pooler_config=pooler_config, + ) + + pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) + + return ClassifierPooler( + pooling=pooling, + classifier=classifier, + act_fn=act_fn, + ) + + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + """Determine which pooling tasks are supported.""" + raise NotImplementedError + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + """ + Construct the updated pooling parameters to use for a supported task. 
+ """ + return PoolingParamsUpdate() + + @abstractmethod + def forward( + self, + hidden_states: list[torch.Tensor] | torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + raise NotImplementedError + + +class DummyPooler(Pooler): + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"plugin", "score"} + + def forward( + self, + hidden_states: list[torch.Tensor] | torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return hidden_states + + class PoolerHead(nn.Module): def __init__(self, activation: PoolerActivation) -> None: super().__init__() @@ -405,7 +433,7 @@ def __init__(self, activation: PoolerActivation) -> None: def forward( self, - pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooled_data: list[torch.Tensor] | torch.Tensor, pooling_metadata: PoolingMetadata, ): return self.activation(pooled_data) @@ -416,16 +444,15 @@ def __init__(self) -> None: super().__init__(activation=PoolerNormalize()) # Load ST projector if available - vllm_config = get_current_vllm_config() - self.projector: Optional[nn.Module] = ( + self.projector: nn.Module | None = ( _load_st_projector(vllm_config.model_config) if vllm_config else None ) self.head_dtype = vllm_config.model_config.head_dtype def forward( self, - pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooled_data: list[torch.Tensor] | torch.Tensor, pooling_metadata: PoolingMetadata, ): if isinstance(pooled_data, list): @@ -471,39 +498,6 @@ def forward( return pooled_data -class RewardPoolerHead(PoolerHead): - def __init__(self) -> None: - super().__init__(activation=PoolerClassify(static_num_labels=False)) - - vllm_config = get_current_vllm_config() - self.head_dtype = vllm_config.model_config.head_dtype - - def forward( - self, - pooled_data: Union[list[torch.Tensor], torch.Tensor], - pooling_metadata: PoolingMetadata, - ): - if isinstance(pooled_data, list): - pooled_data = [p.to(self.head_dtype) for p in pooled_data] - else: - pooled_data = pooled_data.to(self.head_dtype) - - pooling_params = get_pooling_params(pooling_metadata) - - # for softmax - flags = [p.softmax for p in pooling_params] - if len(set(flags)) == 1: - if flags[0]: - pooled_data = self.activation(pooled_data) - else: - pooled_data = [ - self.activation(vecs) if f else vecs - for vecs, f in zip(pooled_data, flags) - ] - - return pooled_data - - class SimplePooler(Pooler): """A layer that pools specific information from hidden states. @@ -513,20 +507,6 @@ class SimplePooler(Pooler): 3. Returns structured results as `PoolerOutput`. 
""" - @classmethod - def from_config( - cls, - pooler_config: ResolvedPoolingConfig, - ) -> "SimplePooler": - pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type) - if pooler_config.task == "embed": - head = EmbeddingPoolerHead() - elif pooler_config.task == "encode": - head = RewardPoolerHead() - else: - raise NotImplementedError(f"Unknown task: {pooler_config.task}") - return cls(pooling, head) - def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: super().__init__() @@ -541,7 +521,7 @@ def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: def forward( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) @@ -549,58 +529,6 @@ def forward( return pooled_data -class StepPooler(Pooler): - def __init__( - self, - ) -> None: - super().__init__() - - self.pooling = AllPool() - self.head = RewardPoolerHead() - - def extract_states( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - pooled_data_lst = self.pooling(hidden_states, pooling_metadata) - prompt_token_ids = get_prompt_token_ids(pooling_metadata) - - pooled_data = list[torch.Tensor]() - - pooling_params = get_pooling_params(pooling_metadata) - - for data, token_id, pooling_param in zip( - pooled_data_lst, prompt_token_ids, pooling_params - ): - step_tag_id = pooling_param.step_tag_id - returned_token_ids = pooling_param.returned_token_ids - - if returned_token_ids is not None and len(returned_token_ids) > 0: - data = data[:, returned_token_ids] - - if step_tag_id is not None: - data = data[token_id == step_tag_id] - pooled_data.append(data) - - return pooled_data - - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"encode"} - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return PoolingParamsUpdate(requires_token_ids=True) - - def forward( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - pooled_data = self.extract_states(hidden_states, pooling_metadata) - pooled_data = self.head(pooled_data, pooling_metadata) - return pooled_data - - class ClassifierPooler(Pooler): """A pooling layer for classification tasks. 
@@ -611,27 +539,47 @@ class ClassifierPooler(Pooler): """ @staticmethod - def act_fn_for_seq_cls(config: ModelConfig): - return get_classification_activation_function(config.hf_config) + def act_fn_for_seq_cls(model_config: ModelConfig): + return get_classification_activation_function(model_config.hf_config) + + @staticmethod + def act_fn_for_cross_encoder(model_config: ModelConfig): + return get_cross_encoder_activation_function(model_config.hf_config) @staticmethod - def act_fn_for_cross_encoder(config: ModelConfig): - return get_cross_encoder_activation_function(config.hf_config) + def resolve_act_fn( + model_config: ModelConfig, + static_num_labels: bool = True, + act_fn: PoolerActivation | str | None = None, + ): + if isinstance(act_fn, str): + if act_fn == "classify": + return ClassifierPooler.act_fn_for_seq_cls(model_config) + elif act_fn == "score": + return ClassifierPooler.act_fn_for_cross_encoder(model_config) + else: + raise ValueError(f"act_fn [{act_fn=}] not supported.") + elif act_fn is None: + return PoolerClassify(static_num_labels=static_num_labels) + else: + assert callable(act_fn) + return act_fn def __init__( self, pooling: PoolingFn, - classifier: Optional[ClassifierFn], - act_fn: Optional[PoolerActivation] = None, + classifier: ClassifierFn | None, + act_fn: PoolerActivation | str | None = None, ) -> None: super().__init__() vllm_config = get_current_vllm_config() - self.pooling = pooling self.classifier = classifier - self.act_fn = act_fn or PoolerClassify() - self.logit_bias: Optional[float] = ( + self.act_fn = self.resolve_act_fn( + vllm_config.model_config, static_num_labels=True, act_fn=act_fn + ) + self.logit_bias: float | None = ( vllm_config.model_config.pooler_config.logit_bias ) self.head_dtype = vllm_config.model_config.head_dtype @@ -641,7 +589,7 @@ def get_supported_tasks(self) -> Set[PoolingTask]: def forward( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) @@ -672,6 +620,150 @@ def forward( return scores +class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): + def forward( + self, pooled_data: torch.Tensor, pooling_param: PoolingParams + ) -> torch.Tensor: + pooled_data = pooled_data.to(self.head_dtype) + # pooled_data shape: [n_tokens, hidden_dimension] + + # Apply ST projector + if self.projector is not None: + pooled_data = self.projector(pooled_data) + # pooled_data shape: [n_tokens, embedding_dimension] + + # for matryoshka representation + pooled_data = pooled_data[..., : pooling_param.dimensions] + + # for normalize + if pooling_param.normalize: + pooled_data = self.activation(pooled_data) + + # pooled_data shape: [n_tokens, embedding_dimension] + return pooled_data + + +class TokenClassifierPoolerHead(nn.Module): + def __init__( + self, + classifier: ClassifierFn | None, + act_fn: PoolerActivation | str | None = None, + ) -> None: + super().__init__() + vllm_config = get_current_vllm_config() + + self.classifier = classifier + self.act_fn = ClassifierPooler.resolve_act_fn( + vllm_config.model_config, static_num_labels=False, act_fn=act_fn + ) + self.logit_bias: float | None = ( + vllm_config.model_config.pooler_config.logit_bias + ) + self.head_dtype = vllm_config.model_config.head_dtype + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_classify"} + + def forward( + self, + hidden_states: torch.Tensor, + pooling_param: PoolingParams, + ) -> 
torch.Tensor: + hidden_states = hidden_states.to(self.head_dtype) + # hidden_states shape: [n_token, hidden_size] + + if self.classifier is not None: + scores = self.classifier(hidden_states) + else: + scores = hidden_states + # scores shape: [n_token, num_labels] + + if self.logit_bias is not None: + scores -= self.logit_bias + + if pooling_param.activation: + scores = self.act_fn(scores) + + # scores shape: [n_token, num_labels] + return scores + + +class AllPooler(Pooler): + def __init__(self, head: nn.Module | PoolerHead) -> None: + super().__init__() + + self.pooling = AllPool() + self.head = head + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_embed", "token_classify"} + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooling_params = get_pooling_params(pooling_metadata) + assert len(pooled_data) == len(pooling_params) + + pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] + return pooled_data + + +class StepPooler(Pooler): + def __init__(self, head: nn.Module | PoolerHead) -> None: + super().__init__() + + self.pooling = AllPool() + self.head = head + + def extract_states( + self, + hidden_states: torch.Tensor | list[torch.Tensor], + pooling_metadata: PoolingMetadata, + ) -> torch.Tensor | list[torch.Tensor]: + pooled_data_lst = self.pooling(hidden_states, pooling_metadata) + prompt_token_ids = get_prompt_token_ids(pooling_metadata) + + pooled_data = list[torch.Tensor]() + + pooling_params = get_pooling_params(pooling_metadata) + + for data, token_id, pooling_param in zip( + pooled_data_lst, prompt_token_ids, pooling_params + ): + step_tag_id = pooling_param.step_tag_id + returned_token_ids = pooling_param.returned_token_ids + + if returned_token_ids is not None and len(returned_token_ids) > 0: + data = data[:, returned_token_ids] + + if step_tag_id is not None: + data = data[token_id == step_tag_id] + pooled_data.append(data) + + return pooled_data + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_embed", "token_classify"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def forward( + self, + hidden_states: torch.Tensor | list[torch.Tensor], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.extract_states(hidden_states, pooling_metadata) + pooling_params = get_pooling_params(pooling_metadata) + assert len(pooled_data) == len(pooling_params) + + pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] + return pooled_data + + class DispatchPooler(Pooler): """Dispatches calls to a sub-pooler based on the pooling task.""" @@ -695,7 +787,7 @@ def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: def forward( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: poolers_by_task = self.poolers_by_task diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 9d1c66e56e91..b92fb8d266b7 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -12,6 +12,7 @@ "fp8", "ptpc_fp8", "fbgemm_fp8", + "fp_quant", "modelopt", "modelopt_fp4", "bitblas", @@ -102,6 +103,7 @@ def 
get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .experts_int8 import ExpertsInt8Config from .fbgemm_fp8 import FBGEMMFp8Config from .fp8 import Fp8Config + from .fp_quant import FPQuantConfig from .gguf import GGUFConfig from .gptq import GPTQConfig from .gptq_bitblas import GPTQBitBLASConfig @@ -125,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "tpu_int8": Int8TpuConfig, "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, + "fp_quant": FPQuantConfig, "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, "bitblas": BitBLASConfig, diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index b7ebc6f272db..f1943d461187 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from fractions import Fraction -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any +import regex as re import torch from vllm.logger import init_logger @@ -46,8 +47,8 @@ def __init__( group_size: int, sym: bool = True, packing_format: str = "auto_round:auto_gptq", - block_name_to_quantize: Optional[Union[str, list[str]]] = None, - extra_config: Optional[dict[str, Any]] = None, + block_name_to_quantize: str | list[str] | None = None, + extra_config: dict[str, Any] | None = None, data_type: str = "int", backend: str = "auto", ) -> None: @@ -128,11 +129,44 @@ def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": def get_layer_config(self, layer, layer_name: str): def get_config(name: str, quantized: bool = True): - cfg = self.extra_config.get(name, {}) if self.extra_config else {} + if not self.extra_config: + return ( + self.weight_bits if quantized else 16, + self.group_size if quantized else -1, + self.sym if quantized else True, + ) + + # exact match first + if name in self.extra_config: + cfg = self.extra_config[name] + return ( + cfg.get("bits", self.weight_bits if quantized else 16), + cfg.get("group_size", self.group_size if quantized else -1), + cfg.get("sym", self.sym if quantized else True), + ) + + REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\") + for pattern, cfg in self.extra_config.items(): + if not isinstance(pattern, str) or not any( + c in REGEX_SPECIAL_CHARS for c in pattern + ): + continue + + try: + if re.search(re.compile(pattern), name) is not None: + return ( + cfg.get("bits", self.weight_bits if quantized else 16), + cfg.get("group_size", self.group_size if quantized else -1), + cfg.get("sym", self.sym if quantized else True), + ) + except re.error: + # Invalid regex, ignore. + continue + return ( - cfg.get("bits", self.weight_bits if quantized else 16), - cfg.get("group_size", self.group_size if quantized else -1), - cfg.get("sym", self.sym if quantized else True), + self.weight_bits if quantized else 16, + self.group_size if quantized else -1, + self.sym if quantized else True, ) # 1. Exact match from config @@ -176,7 +210,7 @@ def get_config(name: str, quantized: bool = True): f"consistent quant config for {sub_names}" ) - # 5. Fallback + # 5. 
Fallback or try a regular expression match return get_config(layer_name, quantized) def check_quantized(self, weight_bits: int) -> bool: @@ -402,6 +436,12 @@ def apply_ipex_quant_layer(self, layer, prefix: str): return None def get_quant_method(self, layer: torch.nn.Module, prefix: str): + if prefix and self.extra_config: + for layer_name in self.extra_config: + if ( + layer_name == prefix or layer_name == f"model.{prefix}" + ) and self.extra_config[layer_name].get("bits", 16) >= 16: + return UnquantizedLinearMethod() if ( current_platform.is_cpu() or current_platform.is_xpu() diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index d4f667564848..0cf8b69f9f6b 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -13,12 +14,17 @@ LinearMethodBase, UnquantizedLinearMethod, ) -from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter +from vllm.transformers_utils.config import get_safetensors_params_metadata + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) @@ -34,7 +40,7 @@ def __init__( weight_bits: int, group_size: int, zero_point: bool, - modules_to_not_convert: Optional[list[str]] = None, + modules_to_not_convert: list[str] | None = None, ) -> None: super().__init__() self.weight_bits = weight_bits @@ -57,7 +63,7 @@ def __repr__(self) -> str: f"modules_to_not_convert={self.modules_to_not_convert})" ) - def get_name(self) -> QuantizationMethods: + def get_name(self) -> "QuantizationMethods": return "awq" def get_supported_act_dtypes(self) -> list[torch.dtype]: @@ -88,9 +94,14 @@ def from_config(cls, config: dict[str, Any]) -> "AWQConfig": def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]: + ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None: if isinstance(layer, LinearBase): - if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + if is_layer_skipped( + prefix, + self.modules_to_not_convert, + self.packed_modules_mapping, + skip_with_substr=True, + ): return UnquantizedLinearMethod() return AWQLinearMethod(self) elif isinstance(layer, FusedMoE): @@ -128,9 +139,26 @@ def get_quant_method( return AWQMoEMethod(awq_marlin_config, layer.moe_config) return None + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.modules_to_not_convert: + self.modules_to_not_convert = hf_to_vllm_mapper.apply_list( + self.modules_to_not_convert + ) + + def maybe_update_config(self, model_name: str, revision: str | None = None): + if self.modules_to_not_convert: + return -def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]): - return any(module_name in prefix for 
module_name in modules_to_not_convert) + unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] + metadata = get_safetensors_params_metadata(model_name, revision=revision) + layers = {param_name.rsplit(".", 1)[0] for param_name in metadata} + quant_layers: set[str] = { + param_name.rsplit(".", 1)[0] + for param_name, info in metadata.items() + if (dtype := info.get("dtype", None)) + and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes + } + self.modules_to_not_convert = list(layers - quant_layers) class AWQLinearMethod(LinearMethodBase): @@ -227,7 +255,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: qweight = layer.qweight scales = layer.scales diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 5d142387d4d9..daf7422963f3 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn import Parameter import vllm.model_executor.layers.fused_moe # noqa @@ -13,6 +15,7 @@ FusedMoEConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, @@ -25,8 +28,7 @@ UnquantizedLinearMethod, set_weight_attrs, ) -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.awq import AWQConfig, is_layer_skipped_awq +from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -47,10 +49,16 @@ verify_marlin_supported, verify_marlin_supports_shape, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.transformers_utils.config import get_safetensors_params_metadata + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) @@ -70,7 +78,7 @@ def __init__( group_size: int, zero_point: bool, lm_head_quantized: bool, - modules_to_not_convert: Optional[list[str]], + modules_to_not_convert: list[str] | None, full_config: dict[str, Any], ) -> None: super().__init__() @@ -104,7 +112,7 @@ def __repr__(self) -> str: ) @classmethod - def get_name(cls) -> QuantizationMethods: + def get_name(cls) -> "QuantizationMethods": return "awq_marlin" @classmethod @@ -140,7 +148,7 @@ def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> Optional["QuantizationMethods"]: can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) is_valid_user_quant = ( 
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin" @@ -169,7 +177,12 @@ def get_quant_method( if isinstance(layer, LinearBase) or ( isinstance(layer, ParallelLMHead) and self.lm_head_quantized ): - if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + if is_layer_skipped( + prefix, + self.modules_to_not_convert, + self.packed_modules_mapping, + skip_with_substr=True, + ): return UnquantizedLinearMethod() # Check if the layer is supported by AWQMarlin. if not check_marlin_supports_layer(layer, self.group_size): @@ -184,8 +197,10 @@ def get_quant_method( elif isinstance(layer, FusedMoE): from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config - if is_layer_skipped_awq( - prefix, getattr(self, "modules_to_not_convert", []) + if is_layer_skipped( + prefix, + getattr(self, "modules_to_not_convert", []), + skip_with_substr=True, ): return UnquantizedFusedMoEMethod(layer.moe_config) if not check_moe_marlin_supports_layer(layer, self.group_size): @@ -224,6 +239,27 @@ def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]): quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point ) + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.modules_to_not_convert: + self.modules_to_not_convert = hf_to_vllm_mapper.apply_list( + self.modules_to_not_convert + ) + + def maybe_update_config(self, model_name: str, revision: str | None = None): + if self.modules_to_not_convert: + return + + unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] + metadata = get_safetensors_params_metadata(model_name, revision=revision) + layers = {param_name.rsplit(".", 1)[0] for param_name in metadata} + quant_layers: set[str] = { + param_name.rsplit(".", 1)[0] + for param_name, info in metadata.items() + if (dtype := info.get("dtype", None)) + and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes + } + self.modules_to_not_convert = list(layers - quant_layers) + class AWQMarlinLinearMethod(LinearMethodBase): """Linear method for AWQ Marlin. 
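# --- Illustrative aside (not part of the diff): how the maybe_update_config() helpers added to
# AWQConfig/AWQMarlinConfig derive modules_to_not_convert when the checkpoint config omits it.
# Layers whose safetensors parameters are stored in a plain float dtype are treated as
# unquantized; everything else stays on the quantized path. The metadata dict and dtype codes
# below are hypothetical stand-ins for get_safetensors_params_metadata() / safetensors' dtype map.
def infer_modules_to_not_convert(metadata: dict[str, dict[str, str]]) -> list[str]:
    unquant_dtype_codes = {"F16", "BF16", "F32"}
    layers = {name.rsplit(".", 1)[0] for name in metadata}
    quant_layers = {
        name.rsplit(".", 1)[0]
        for name, info in metadata.items()
        if (dtype := info.get("dtype")) and dtype not in unquant_dtype_codes
    }
    return sorted(layers - quant_layers)

example_metadata = {
    "model.layers.0.mlp.gate.weight": {"dtype": "F16"},            # plain fp16 -> skip quantization
    "model.layers.0.self_attn.q_proj.qweight": {"dtype": "I32"},   # packed int weights -> quantized
    "model.layers.0.self_attn.q_proj.scales": {"dtype": "F16"},
}
print(infer_modules_to_not_convert(example_metadata))  # ['model.layers.0.mlp.gate']
# --- end aside ---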
@@ -360,7 +396,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return apply_awq_marlin_linear( input=x, @@ -555,7 +591,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return None def apply( @@ -566,21 +602,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: @@ -603,7 +639,7 @@ def apply( indices_type=self.topk_indices_dtype, ) - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_qweight, layer.w2_qweight, diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 26f5e8bb6c7d..c8a8424eb5c8 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -3,7 +3,7 @@ import inspect from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import torch from torch import nn @@ -105,7 +105,7 @@ def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: """ Detects if this quantization method can support a given checkpoint format by overriding the user specified quantization method -- @@ -135,7 +135,7 @@ def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> A @abstractmethod def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional[QuantizeMethodBase]: + ) -> QuantizeMethodBase | None: """Get the quantize method to use for the quantized layer. 
Args: @@ -147,7 +147,7 @@ def get_quant_method( """ raise NotImplementedError - def get_cache_scale(self, name: str) -> Optional[str]: + def get_cache_scale(self, name: str) -> str | None: return None def apply_vllm_mapper( # noqa: B027 diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index d2e0582be197..be15f20cac21 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -45,10 +45,10 @@ class BitBLASConfig(QuantizationConfig): def __init__( self, weight_bits: int, - group_size: Optional[int], - desc_act: Optional[bool], - is_sym: Optional[bool], - quant_method: Optional[str], + group_size: int | None, + desc_act: bool | None, + is_sym: bool | None, + quant_method: str | None, lm_head_quantized: bool, ) -> None: try: @@ -160,7 +160,7 @@ def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: # compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq <=0.7.1 is_bitblas_format: bool is_bitblas_format = hf_quant_cfg.get( @@ -469,7 +469,7 @@ def apply_gptq( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: qweight = layer.qweight scales = layer.scales diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 80ed121bd85b..ccd9b311cc93 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Union import torch from packaging import version @@ -22,7 +23,7 @@ QuantizationMethods, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class BitsAndBytesConfig(QuantizationConfig): @@ -41,7 +42,7 @@ def __init__( bnb_4bit_use_double_quant: bool = False, llm_int8_enable_fp32_cpu_offload: bool = False, llm_int8_has_fp16_weight: bool = False, - llm_int8_skip_modules: Optional[list[str]] = None, + llm_int8_skip_modules: list[str] | None = None, llm_int8_threshold: float = 6.0, ) -> None: super().__init__() @@ -138,7 +139,7 @@ def get_safe_value(config, keys, default_value=None): def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]: + ) -> Union["LinearMethodBase", "BitsAndBytesMoEMethod"] | None: if isinstance(layer, LinearBase): if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules): return UnquantizedLinearMethod() @@ -268,7 +269,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.quant_config.load_in_8bit: return self._apply_8bit_weight(layer, x, bias) @@ -279,7 +280,7 @@ def _apply_8bit_weight( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: # only load the bitsandbytes module when needed from bitsandbytes import MatmulLtState, matmul @@ 
-359,7 +360,7 @@ def _apply_4bit_weight( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: original_type = x.dtype original_shape = x.shape @@ -489,7 +490,7 @@ def create_weights( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return None def apply( @@ -500,21 +501,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts assert self.fused_experts is None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e89d002078ac..6c7d4cd7bd9a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -71,7 +71,7 @@ __all__ = ["CompressedTensorsLinearMethod"] SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config" -QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]] +QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, QuantizationArgs] | None] class CompressedTensorsConfig(QuantizationConfig): @@ -82,9 +82,9 @@ def __init__( quant_format: str, sparsity_scheme_map: dict[str, SparsityCompressionConfig], sparsity_ignore_list: list[str], - kv_cache_scheme: Optional[dict[str, Any]] = None, - config: Optional[dict[str, Any]] = None, - transform_config: Optional[dict[str, Any]] = None, + kv_cache_scheme: dict[str, Any] | None = None, + config: dict[str, Any] | None = None, + transform_config: dict[str, Any] | None = None, ): super().__init__() self.ignore = ignore @@ -310,7 +310,7 @@ def _is_fp4a4_nvfp4( ) is_float_type = ( weight_quant.type == QuantizationType.FLOAT - and input_quant.type == QuantizationType.FLOAT.value + and input_quant.type == QuantizationType.FLOAT ) is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4 @@ -524,7 +524,7 @@ def _get_scheme_from_parts( self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs, - format: Optional[str] = None, + format: str | None = None, ) -> "CompressedTensorsScheme": # use the per-layer format if defined, otherwise, use global format format = format if format is not None else self.quant_format @@ -631,7 +631,7 @@ 
def _get_scheme_from_parts( raise NotImplementedError("No compressed-tensors compatible scheme was found.") def get_scheme( - self, layer: torch.nn.Module, layer_name: Optional[str] = None + self, layer: torch.nn.Module, layer_name: str | None = None ) -> Optional["CompressedTensorsScheme"]: """ compressed-tensors supports non uniform in the following way: @@ -674,7 +674,7 @@ def get_scheme( sparsity_targets = self.sparsity_scheme_map.keys() - set( self.sparsity_ignore_list ) - sparsity_scheme: Optional[SparsityCompressionConfig] = None + sparsity_scheme: SparsityCompressionConfig | None = None with suppress(ValueError): matched_target = find_matched_target( layer_name=layer_name, @@ -723,7 +723,7 @@ def get_scheme( logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name) return scheme - def get_cache_scale(self, name: str) -> Optional[str]: + def get_cache_scale(self, name: str) -> str | None: """ Check whether the param name matches the format for k/v cache scales in compressed-tensors. If this is the case, return its equivalent @@ -751,9 +751,9 @@ def has_blocked_weights(self) -> bool: @staticmethod def supports_cutlass_24( - weight_quant: Optional[QuantizationArgs], - input_quant: Optional[QuantizationArgs], - sparsity_scheme: Optional[SparsityCompressionConfig] = None, + weight_quant: QuantizationArgs | None, + input_quant: QuantizationArgs | None, + sparsity_scheme: SparsityCompressionConfig | None = None, ) -> bool: """ Check if the layer is supported by the Cutlass 2:4 Kernel @@ -853,7 +853,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ): """ Use the output of create_weights and the CompressedTensorsScheme @@ -878,7 +878,7 @@ def __init__(self, quant_config: CompressedTensorsConfig): super().__init__(quant_config) @staticmethod - def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]): + def validate_kv_cache_scheme(kv_cache_scheme: dict[str, Any] | None): """ Validator for the kv cache scheme. Useful for controlling the kv cache quantization schemes, that are being supported in vLLM diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 41e7f1c7a499..bf38c15b4701 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum +from collections.abc import Callable from enum import Enum -from typing import Callable, Optional, Union import torch from compressed_tensors import CompressionFormat @@ -34,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe, ) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP, @@ -141,9 +142,12 @@ def get_moe_method( # group_size=None means channelwise group_size = weight_quant.group_size or -1 # Prefer to use the MarlinMoE kernel when it is supported. 
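# --- Illustrative aside (not part of the diff): the target-matching convention behind
# find_matched_target()/should_ignore_layer() used in get_scheme() above. A target matches a
# layer name either exactly or, when prefixed with "re:", as a regular expression. Minimal
# sketch of that convention only, not the vLLM implementation.
import re

def matches_target(layer_name: str, target: str) -> bool:
    if target.startswith("re:"):
        return re.match(target[3:], layer_name) is not None
    return layer_name == target

print(matches_target("model.layers.3.mlp.down_proj", "re:.*down_proj$"))                # True
print(matches_target("model.layers.3.mlp.gate_proj", "model.layers.3.mlp.gate_proj"))   # True
print(matches_target("model.layers.3.mlp.gate_proj", "re:.*up_proj$"))                  # False
# --- end aside ---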
- if not check_moe_marlin_supports_layer(layer, group_size): + if ( + not check_moe_marlin_supports_layer(layer, group_size) + or current_platform.is_rocm() + ): if ( - weight_quant.strategy in QuantizationStrategy.GROUP + weight_quant.strategy == QuantizationStrategy.GROUP and weight_quant.actorder in (ActivationOrdering.GROUP, ActivationOrdering.DYNAMIC) ): @@ -303,10 +307,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight = torch.nn.Parameter( layer.w13_weight_packed.data, requires_grad=False ) + delattr(layer, "w13_weight_packed") layer.w2_weight = torch.nn.Parameter( layer.w2_weight_packed.data, requires_grad=False ) + delattr(layer, "w2_weight_packed") # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel. if self.allow_flashinfer: @@ -372,7 +378,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: (layer.w2_input_global_scale), requires_grad=False ) - def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None: if self.use_marlin: return None elif not self.allow_flashinfer: @@ -399,7 +405,7 @@ def select_gemm_impl( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: if self.use_marlin: return None @@ -420,21 +426,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet." 
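# --- Illustrative aside (not part of the diff): why `weight_quant.strategy in
# QuantizationStrategy.GROUP` was changed to `==` above. Assuming QuantizationStrategy is a
# str-backed Enum (as in compressed-tensors), `in` against a *member* falls through to plain
# substring containment, which is not the membership/equality test the check intends.
from enum import Enum

class Strategy(str, Enum):   # stand-in for QuantizationStrategy
    GROUP = "group"
    TENSOR_GROUP = "tensor_group"

print(Strategy.GROUP in Strategy.TENSOR_GROUP)   # True  -- "group" is merely a substring
print(Strategy.GROUP == Strategy.TENSOR_GROUP)   # False -- exact comparison, as intended
print(Strategy.GROUP == Strategy.GROUP)          # True
# --- end aside ---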
@@ -462,7 +468,7 @@ def apply( # if self.use_marlin: assert self.fused_experts is None - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, @@ -847,7 +853,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - rocm_aiter_fused_experts, shuffle_weights, ) @@ -913,7 +918,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight_scale ) - def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None: if self.use_marlin or self.rocm_aiter_moe_enabled: return None else: @@ -997,7 +1002,7 @@ def select_gemm_impl( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: if self.use_marlin: return None @@ -1022,21 +1027,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet." @@ -1055,6 +1060,7 @@ def apply( routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype, + num_fused_shared_experts=layer.num_fused_shared_experts, ) per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN @@ -1067,7 +1073,7 @@ def apply( if self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." 
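# --- Illustrative aside (not part of the diff): the delattr(layer, "w13_weight_packed") /
# delattr(layer, "w2_weight_packed") calls added in process_weights_after_loading() above.
# Re-wrapping the packed tensor's .data under the final parameter name shares storage, so the
# delattr only drops the stale registration; nothing is copied. Hypothetical module below.
import torch

m = torch.nn.Module()
m.w13_weight_packed = torch.nn.Parameter(torch.empty(4, 4), requires_grad=False)

m.w13_weight = torch.nn.Parameter(m.w13_weight_packed.data, requires_grad=False)
assert m.w13_weight.data_ptr() == m.w13_weight_packed.data_ptr()   # same storage, no copy

delattr(m, "w13_weight_packed")
assert set(dict(m.named_parameters())) == {"w13_weight"}           # stale name no longer registered
# --- end aside ---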
assert self.fused_experts is None - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, @@ -1280,7 +1286,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return int8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, @@ -1297,21 +1303,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: @@ -1604,7 +1610,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return None def apply( @@ -1615,21 +1621,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: @@ -1654,7 +1660,7 @@ def apply( indices_type=self.topk_indices_dtype, ) - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_weight_packed, layer.w2_weight_packed, @@ -1856,7 +1862,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> 
Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: assert self.num_bits == 4 or self.num_bits == 8 config_builder = ( int4_w4a16_moe_quant_config @@ -1880,21 +1886,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: @@ -2092,7 +2098,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def _pack_matrix( int4_as_int8_2d: torch.Tensor, scales_2d: torch.Tensor, - bias_1d: Optional[torch.Tensor], + bias_1d: torch.Tensor | None, in_features: int, out_features: int, ) -> torch.Tensor: @@ -2192,7 +2198,7 @@ def _pack_matrix( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: # CPU dynamic 4-bit MoE path does not use modular kernels or # fused_experts; quant config is not needed. return None @@ -2205,20 +2211,20 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor: assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet." 
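# --- Illustrative aside (not part of the diff): the typing changes repeated throughout this
# patch. PEP 604 unions (`X | None`, `A | B`) replace Optional/Union, and Callable now comes
# from collections.abc rather than typing; on Python 3.10 and newer both spellings are
# equivalent at runtime. Toy function only.
from collections.abc import Callable

def scaled(x: int, scale: float | None = None, post: Callable[[float], float] | None = None) -> float:
    y = x * (scale if scale is not None else 1.0)
    return post(y) if post is not None else y

print(scaled(3))                      # 3.0
print(scaled(3, scale=2.0))           # 6.0
print(scaled(3, post=lambda v: -v))   # -3.0
# --- end aside ---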
assert activation in ("silu", "swigluoai", "swiglu"), ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index fc0634394ece..ca286675ebd0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -15,6 +15,7 @@ from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 +# This avoids circular import error from .compressed_tensors_24 import CompressedTensors24 # isort: skip __all__ = [ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 068eecf5e026..571ce267f3fa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import torch from compressed_tensors import CompressionFormat, ModelCompressor @@ -42,23 +43,23 @@ class CompressedTensors24(CompressedTensorsScheme): def __init__( self, quantized: bool = False, - weight_quant: Optional[QuantizationArgs] = None, - input_quant: Optional[QuantizationArgs] = None, - model_compression_config: Optional[dict[str, Any]] = None, + weight_quant: QuantizationArgs | None = None, + input_quant: QuantizationArgs | None = None, + model_compression_config: dict[str, Any] | None = None, ): self.quantized = quantized self.weight_quant = weight_quant self.input_quant = input_quant - self.model_compressor = ( - ModelCompressor.from_compression_config(model_compression_config) - if model_compression_config is not None - else None + model_compressor = ModelCompressor.from_compression_config( + model_compression_config ) self.do_sparse_decompress = ( - self.model_compressor is not None - and self.model_compressor.sparsity_config.format + model_compressor is not None + and model_compressor.sparsity_config.format == CompressionFormat.sparse_24_bitmask.value ) + if self.do_sparse_decompress: + self.model_compressor = model_compressor if ( quantized @@ -247,7 +248,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: """ Returns the output tensor for the layer with 2:4 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index 688621cbf79a..a7f9076db7e9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Optional import torch @@ -33,7 +32,7 @@ def create_weights(self, *args, **kwargs): @abstractmethod def apply_weights( - self, layer: torch.nn.Module, x: 
torch.Tensor, bias: Optional[torch.Tensor] + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ): """ Run the forward pass for the particular scheme. This is where diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index af06418c959d..dd0f4b3d868d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from torch.nn import Parameter @@ -30,7 +30,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): - def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None): + def __init__(self, strategy: str, num_bits: int, group_size: int | None = None): self.strategy = strategy self.group_size = group_size self.tile_size = 16 @@ -143,7 +143,7 @@ def create_weights( layer.workspace = workspace def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ) -> torch.Tensor: qweight = layer.weight_packed meta = layer.meta diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py index a96f51538b38..3afadc6eb7e5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from torch.nn.parameter import Parameter @@ -110,7 +110,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return apply_fp4_marlin_linear( input=x, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 676f4de6ee7b..4127cd2d574b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from torch.nn.parameter import Parameter @@ -14,7 +14,10 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 run_nvfp4_emulations, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + cutlass_fp4_supported, + swizzle_blockscale, +) from 
vllm.model_executor.parameter import ( GroupQuantScaleParameter, ModelWeightParameter, @@ -29,10 +32,12 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): def __init__(self): - if envs.VLLM_USE_TRTLLM_FP4_GEMM: - assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" - self.backend = "flashinfer-trtllm" - logger.info_once("Using flashinfer-trtllm for FP4") + self.backend = "none" + if envs.VLLM_NVFP4_GEMM_BACKEND is None: + if has_flashinfer(): + self.backend = "flashinfer-cutlass" + elif cutlass_fp4_supported(): + self.backend = "cutlass" elif envs.VLLM_USE_FBGEMM: self.backend = "fbgemm" try: @@ -42,12 +47,17 @@ def __init__(self): "Backend fbgemm requires fbgemm.f4f4bf16 operator, " "Please install with: pip install fbgemm-gpu-genai" ) from exc - logger.info_once("Using FGBEMM-GPU-GENAI for FP4") - elif has_flashinfer(): - self.backend = "flashinfer-cutlass" - logger.info_once("Using flashinfer-cutlass for FP4") - else: - self.backend = "cutlass" + elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"): + self.backend = envs.VLLM_NVFP4_GEMM_BACKEND + assert has_flashinfer(), f"FlashInfer is required for {self.backend}" + + if self.backend == "none": + raise ValueError( + "No valid NVFP4 GEMM backend found. " + "Please check your platform capability." + ) + + logger.info_once(f"Using {self.backend} for NVFP4 GEMM") self.group_size = 16 @classmethod @@ -156,7 +166,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: if envs.VLLM_USE_NVFP4_CT_EMULATIONS: out = run_nvfp4_emulations( @@ -184,10 +194,9 @@ def apply_weights( layer.alpha, output_dtype, ) - if self.backend == "flashinfer-trtllm": - out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") - elif self.backend == "flashinfer-cutlass": - out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + if self.backend.startswith("flashinfer-"): + backend_name = self.backend[len("flashinfer-") :] + out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) elif self.backend == "fbgemm": out = torch.ops.fbgemm.f4f4bf16( x_fp4, @@ -198,6 +207,7 @@ def apply_weights( use_mx=False, ).to(output_dtype) else: + assert self.backend == "cutlass" out = cutlass_scaled_fp4_mm(*mm_args) if bias is not None: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py index 59d99e1e1c90..a23961e89753 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from compressed_tensors.quantization import ActivationOrdering @@ -41,9 +41,9 @@ def __init__( self, strategy: str, num_bits: int, - group_size: Optional[int] = None, - symmetric: Optional[bool] = True, - actorder: Optional[ActivationOrdering] = None, + group_size: int | None = None, + symmetric: bool | None = True, + actorder: ActivationOrdering | None = None, ): self.pack_factor = 32 // num_bits self.strategy = strategy @@ -178,6 +178,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: 
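# --- Illustrative aside (not part of the diff): the selection shape used by the reworked
# CompressedTensorsW4A4Fp4.__init__ above -- auto-detect when VLLM_NVFP4_GEMM_BACKEND is unset,
# accept "flashinfer-<impl>" overrides by prefix, and fail loudly otherwise. Simplified sketch
# with hypothetical probe flags; the fbgemm branch and logging are omitted.
def choose_nvfp4_backend(override: str | None, has_flashinfer: bool, cutlass_ok: bool) -> str:
    backend = "none"
    if override is None:                        # no env override: probe in preference order
        if has_flashinfer:
            backend = "flashinfer-cutlass"
        elif cutlass_ok:
            backend = "cutlass"
    elif override.startswith("flashinfer-"):    # e.g. "flashinfer-trtllm"
        assert has_flashinfer, f"FlashInfer is required for {override}"
        backend = override
    if backend == "none":
        raise ValueError("No valid NVFP4 GEMM backend found.")
    return backend

print(choose_nvfp4_backend(None, has_flashinfer=True, cutlass_ok=False))   # flashinfer-cutlass
print(choose_nvfp4_backend("flashinfer-trtllm", True, True))               # flashinfer-trtllm
# apply_weights() later strips the prefix: backend[len("flashinfer-"):] -> "trtllm" / "cutlass".
# --- end aside ---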
self.kernel.process_weights_after_loading(layer) def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py index 61a9f6b75cb1..aa0c52beda2b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch @@ -36,7 +36,7 @@ def __init__( self, strategy: str, num_bits: int, - group_size: Optional[int] = None, + group_size: int | None = None, is_static_input_scheme: bool = False, input_symmetric: bool = True, ): @@ -148,6 +148,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index 709d2538e6ad..904a9f5d4907 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from compressed_tensors.quantization import QuantizationStrategy @@ -125,7 +125,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return apply_fp8_marlin_linear( input=x, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 902c9c7bde97..ee431c9148b8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy @@ -179,7 +179,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.weight_block_size is not None: return self.w8a8_block_fp8_linear.apply( diff --git 
a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 70316a7553ca..6fd0a6a1c822 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from compressed_tensors.quantization import QuantizationStrategy @@ -120,6 +120,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 188fc15fd948..2267395fe67d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch from compressed_tensors.quantization import ActivationOrdering @@ -42,9 +42,9 @@ def __init__( self, strategy: str, num_bits: int, - group_size: Optional[int] = None, - symmetric: Optional[bool] = True, - actorder: Optional[ActivationOrdering] = None, + group_size: int | None = None, + symmetric: bool | None = True, + actorder: ActivationOrdering | None = None, ): self.pack_factor = 32 // num_bits self.strategy = strategy @@ -214,6 +214,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py index a51fe28b975e..bd1964e667d9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Generator +from collections.abc import Callable, Generator from itertools import accumulate -from typing import Callable, Optional import torch from compressed_tensors.transform import ( @@ -16,7 +15,6 @@ from vllm.model_executor.layers.linear import ( WEIGHT_LOADER_V2_SUPPORTED, LinearMethodBase, - QKVCrossParallelLinear, ) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 
CompressedTensorsScheme, @@ -39,7 +37,7 @@ class CompressedTensorsLinearTransformMethod(LinearMethodBase): def from_schemes( cls, quant_method: LinearMethodBase, - quant_scheme: Optional[CompressedTensorsScheme], + quant_scheme: CompressedTensorsScheme | None, input_tfms: dict[int, TransformTuple], output_tfms: dict[int, TransformTuple], ) -> "CompressedTensorsLinearTransformMethod": @@ -67,8 +65,8 @@ def __init__( self.input_tfms = input_tfms self.output_tfms = output_tfms - self.input_transform: Optional[HadamardTransform] = None - self.output_transform: Optional[HadamardTransform] = None + self.input_transform: HadamardTransform | None = None + self.output_transform: HadamardTransform | None = None def create_weights( self, @@ -89,10 +87,7 @@ def create_weights( # hack around this by getting weight loader v1 so ULM can load correctly quant_method_name = self.quant_method.__class__.__name__ if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED: - if isinstance(layer, QKVCrossParallelLinear): - weight_loader_v1 = layer.weight_loader_v1 - else: - weight_loader_v1 = layer.weight_loader + weight_loader_v1 = layer.weight_loader extra_weight_attrs["weight_loader"] = weight_loader_v1 self.quant_method.create_weights( @@ -155,7 +150,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.input_transform is not None: x = self.input_transform(x) @@ -168,7 +163,7 @@ def apply( if self.output_transform is not None: for part_id, (start, length) in enumerate(self.partition_ranges): x[:, start : start + length] = self.output_transform( - x[:, start : start + length].contiguous(), part_id=part_id + x[:, start : start + length].clone(), part_id=part_id ) return x @@ -198,7 +193,7 @@ def _validate_tfm_schemes(self, num_partitions: int): def get_linear_transform_schemes( layer: torch.nn.Module, layer_name: str, - transform_config: Optional[TransformConfig], + transform_config: TransformConfig | None, packed_modules_mapping: dict[str, list[str]], ) -> tuple[ dict[int, TransformTuple], dict[int, TransformTuple] @@ -230,7 +225,7 @@ def get_linear_transform_schemes( def get_schemes_args( - transform_config: Optional[TransformConfig], + transform_config: TransformConfig | None, ) -> Generator[tuple[str, TransformScheme, TransformArgs]]: if transform_config is None: return diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py index ecd798257fce..f5589c8c07fa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from collections.abc import Hashable -from typing import Callable +from collections.abc import Callable, Hashable import torch from compressed_tensors.transform import ( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py index b800c5f5d436..f0bb47a728ad 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -17,7 +16,7 @@ def is_qutlass_fp4_scheme( - quant_scheme: Optional[CompressedTensorsScheme], + quant_scheme: CompressedTensorsScheme | None, input_tfms: dict[int, TransformTuple], ) -> bool: return ( @@ -60,6 +59,6 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index ed326197295d..25c7d335da20 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -145,7 +144,7 @@ def triton_scaled_mm( scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, block_size_m: int = 32, block_size_n: int = 32, block_size_k: int = 32, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index d8beaafff2ef..f88092169110 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -3,7 +3,6 @@ from collections.abc import Iterable, Mapping from types import MappingProxyType -from typing import Optional import regex as re from compressed_tensors import CompressionFormat @@ -21,7 +20,7 @@ def is_activation_quantization_format(format: str) -> bool: def should_ignore_layer( - layer_name: Optional[str], + layer_name: str | None, ignore: Iterable[str] = tuple(), fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: @@ -84,7 +83,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: def find_matched_target( - layer_name: Optional[str], + layer_name: str | None, module: Module, targets: Iterable[str], fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), @@ -134,7 +133,7 @@ def find_matched_target( def _find_first_match( value: str, targets: Iterable[str], check_contains: bool = False -) -> Optional[str]: +) -> str | None: """ Returns first element of target that matches value either exactly or as a regex after 're:'. If check_contains is set to True, @@ -176,7 +175,7 @@ def _match_fused_layer( layer_name: str, target_layers: Iterable[str], fused_mapping: Mapping[str, list[str]], -) -> Optional[str]: +) -> str | None: """ Match a fused layer name to its corresponding individual layer in target_layers. 
Returns first value in fused_mapping which matches targets @@ -205,7 +204,7 @@ def _match_fused_layer( ] # for each unfused component, find a match in targets - unfused_matches: list[Optional[str]] = [] + unfused_matches: list[str | None] = [] for unfused in unfused_paths: for target in target_layers: if _is_equal_or_regex_match(unfused, target): diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 82a2103a19f3..4f742d834573 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -140,7 +140,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: weight = layer.weight y = weight.ds_dequantize() diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 909b04c79f23..754608af97c6 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional import torch @@ -129,7 +130,7 @@ def create_weights( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return int8_w8a16_moe_quant_config( w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None ) @@ -142,21 +143,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 5d390cbd7b1e..6ba18e59e4d5 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -171,7 +171,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.quant_config.use_marlin: return apply_fp8_marlin_linear( diff --git a/vllm/model_executor/layers/quantization/fp8.py 
b/vllm/model_executor/layers/quantization/fp8.py index 2123fd9eba15..e5681cb85625 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import torch from torch.nn import Module @@ -13,6 +14,9 @@ from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, @@ -25,6 +29,7 @@ FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.linear import ( LinearBase, @@ -36,6 +41,7 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, @@ -87,13 +93,15 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import ( + fp8_gemm_nt, get_col_major_tma_aligned_tensor, is_deep_gemm_e8m0_used, is_deep_gemm_supported, + should_use_deepgemm_for_fp8_linear, ) from vllm.utils.flashinfer import has_flashinfer_moe +from vllm.utils.import_utils import has_deep_gemm if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -173,8 +181,8 @@ def __init__( self, is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", - ignored_layers: Optional[list[str]] = None, - weight_block_size: Optional[list[int]] = None, + ignored_layers: list[str] | None = None, + weight_block_size: list[int] | None = None, ) -> None: super().__init__() @@ -298,7 +306,7 @@ def get_quant_method( return Fp8KVCacheMethod(self) return None - def get_cache_scale(self, name: str) -> Optional[str]: + def get_cache_scale(self, name: str) -> str | None: """ Check whether the param name matches the format for k/v cache scales in compressed-tensors. If this is the case, return its equivalent @@ -351,6 +359,8 @@ def __init__(self, quant_config: Fp8Config): # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False + if vllm_is_batch_invariant(): + self.use_marlin = False self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() @@ -530,8 +540,115 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: + # if batch invariant mode is enabled, prefer DeepGEMM FP8 path + # we will use BF16 dequant when DeepGEMM is not supported. 
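# --- Illustrative aside (not part of the diff): the block-scale dequantization used by the
# batch-invariant BF16 fallback that follows. A block-quantized FP8 weight of shape [N, K]
# carries one scale per (block_n x block_k) tile; expanding the scale grid with
# repeat_interleave and trimming back to [N, K] yields a BF16 weight that plain F.linear can
# consume. Toy sizes, with the scale grid assumed [num_blocks_n, num_blocks_k]-oriented.
import torch

N, K, block_n, block_k = 6, 8, 4, 4
weight_fp8 = torch.randn(N, K).to(torch.float8_e4m3fn)                       # stored as [N, K]
num_blocks_n, num_blocks_k = -(-N // block_n), -(-K // block_k)               # ceil division -> 2 x 2
weight_scale = torch.rand(num_blocks_n, num_blocks_k, dtype=torch.bfloat16)   # one scale per tile

scale_expanded = weight_scale.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
scale_expanded = scale_expanded[:N, :K]                                       # trim padded blocks
weight_bf16 = weight_fp8.to(torch.bfloat16) * scale_expanded

x = torch.randn(2, K, dtype=torch.bfloat16)
print(torch.nn.functional.linear(x, weight_bf16).shape)                       # torch.Size([2, 6])
# --- end aside ---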
+ if vllm_is_batch_invariant(): + if self.block_quant and should_use_deepgemm_for_fp8_linear( + torch.bfloat16, layer.weight, None + ): + # use group quant consistent with block size across K + assert self.act_q_group_shape is not None + q_input, input_scale = QuantFP8( + False, + self.act_q_group_shape, + column_major_scales=True, + )(x) + + output_2d = torch.empty( + (q_input.shape[0], layer.weight.shape[0]), + dtype=torch.bfloat16, + device=q_input.device, + ) + fp8_gemm_nt( + (q_input, input_scale), + (layer.weight, layer.weight_scale), + output_2d, + ) + if bias is not None: + output_2d = output_2d + bias + return output_2d + + # Dequantize FP8 weights to BF16 + weight_fp8 = layer.weight.to(torch.bfloat16) + weight_scale = layer.weight_scale.to(torch.bfloat16) + + # Handle different quantization granularities + if self.block_quant: + # Block-wise quantization: + # - Weight is NOT transposed, shape is [N, K] (output_size, input_size) + # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!) + assert self.weight_block_size is not None + block_n, block_k = self.weight_block_size # Note: order is [N, K] + + N, K = weight_fp8.shape + + # determine expected number of blocks along N and K + num_blocks_n = (N + block_n - 1) // block_n + num_blocks_k = (K + block_k - 1) // block_k + + # scale layout may be [num_blocks_n, num_blocks_k] + # or [num_blocks_k, num_blocks_n] depending on backend + if weight_scale.dim() != 2: + raise RuntimeError( + f"FP8 block scale must be 2D, got {tuple(weight_scale.shape)}" + ) + + scale_rows, scale_cols = weight_scale.shape + if (scale_rows, scale_cols) == (num_blocks_k, num_blocks_n): + if num_blocks_n == num_blocks_k: + # ambiguous square case, warn and skip transpose + logger.warning( + "Batch-invariant FP8: square block-scale %dx%d; " + "skipping transpose to avoid misorientation.", + scale_rows, + scale_cols, + ) + else: + # clear KN -> transpose to NK + weight_scale = weight_scale.t() + + # Expand scale to match weight dimensions + # scale_expanded should have shape [N, K] + scale_expanded = weight_scale.repeat_interleave( + block_n, dim=0 + ).repeat_interleave(block_k, dim=1) + # Trim to exact weight size (in case of padding) + scale_expanded = scale_expanded[:N, :K] + weight_bf16 = weight_fp8 * scale_expanded + else: + # Per-tensor quantization: weight IS transposed to [K, N] + # scale should be scalar or [1] or per-output-channel [N] + if weight_scale.numel() == 1: + # Per-tensor: simple scalar multiplication + weight_bf16 = weight_fp8 * weight_scale + else: + # Multiple scales (fused modules like QKV) + # Try to infer correct broadcasting + # weight is [K, N], scale could be [num_logical_weights] + # Need to figure out how to broadcast - for now just try + # direct multiplication + if ( + weight_scale.dim() == 1 + and weight_scale.shape[0] == weight_fp8.shape[0] + ): + # Per-row scaling + weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1) + else: + # Fallback + weight_bf16 = weight_fp8 * weight_scale + + # For block quant, weight is [N, K], for per-tensor it's [K, N] + # F.linear expects weight to be [N, K], so: + if self.block_quant: + # Already in correct shape [N, K] + output = torch.nn.functional.linear(x, weight_bf16, bias) + else: + # Need to transpose back: [K, N] -> [N, K] + output = torch.nn.functional.linear(x, weight_bf16.t(), bias) + return output + if self.use_marlin: return apply_fp8_marlin_linear( input=x, @@ -584,12 +701,12 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): self.weight_block_size = 
self.quant_config.weight_block_size self.block_quant: bool = self.weight_block_size is not None - self.fused_experts: Optional[mk.FusedMoEModularKernel] = None # type: ignore + self.fused_experts: mk.FusedMoEModularKernel | None = None # type: ignore self.fp8_backend = get_fp8_moe_backend(self.block_quant) self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN - self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None + self.flashinfer_moe_backend: FlashinferMoeBackend | None = None if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: @@ -741,6 +858,8 @@ def create_weights( layer.w13_input_scale = None layer.w2_input_scale = None + self.rocm_aiter_moe_enabled = False + def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( @@ -968,7 +1087,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_weight_scale_inv ) - def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None: if ( self.rocm_aiter_moe_enabled or self.use_marlin @@ -1041,7 +1160,7 @@ def select_gemm_impl( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: if self.use_marlin: return None @@ -1067,21 +1186,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None @@ -1165,6 +1284,7 @@ def apply( global_num_experts=global_num_experts, zero_expert_num=zero_expert_num, zero_expert_type=zero_expert_type, + num_fused_shared_experts=layer.num_fused_shared_experts, ) # @@ -1193,7 +1313,7 @@ def apply( elif self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." 
assert self.fused_experts is None - result = torch.ops.vllm.fused_marlin_moe( + result = fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/fp_quant.py b/vllm/model_executor/layers/quantization/fp_quant.py new file mode 100644 index 000000000000..15a253cef0b7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp_quant.py @@ -0,0 +1,420 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202 + +from typing import Any + +import torch +from torch.nn.parameter import Parameter + +from vllm._custom_ops import ( + cutlass_scaled_fp4_mm, + fusedQuantizeMx, + fusedQuantizeNv, + matmul_mxf4_bf16_tn, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op + + +class FPQuantConfig(QuantizationConfig): + """Config class for FPQuant.""" + + def __init__( + self, + hadamard_group_size: int = 32, + forward_dtype: str = "mxfp4", + forward_method: str = "abs_max", + pseudoquantization: bool = False, + modules_to_not_convert: list[str] | None = None, + ) -> None: + super().__init__() + self.hadamard_group_size = hadamard_group_size + self.forward_dtype = forward_dtype + self.forward_method = forward_method + self.pseudoquantization = pseudoquantization + self.modules_to_not_convert = modules_to_not_convert + + if pseudoquantization: + raise ValueError("Pseudoquantization is not supported for vLLM") + + def __repr__(self) -> str: + return ( + f"FPQuantConfig(hadamard_group_size={self.hadamard_group_size}, " + f"forward_dtype={self.forward_dtype}, " + f"forward_method={self.forward_method}, " + f"pseudoquantization={self.pseudoquantization}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "fp_quant" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] # no extra configs. 
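+ # Illustrative `quantization_config` that `from_config` below parses; the
+ # values mirror the constructor defaults and "lm_head" is only an example:
+ # {
+ #     "quant_method": "fp_quant",
+ #     "hadamard_group_size": 32,
+ #     "forward_dtype": "mxfp4",
+ #     "forward_method": "abs_max",
+ #     "pseudoquantization": false,
+ #     "modules_to_not_convert": ["lm_head"]
+ # }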
+ + @classmethod + def from_config(cls, config: dict[str, Any]) -> "FPQuantConfig": + hadamard_group_size = cls.get_from_keys(config, ["hadamard_group_size"]) + forward_dtype = cls.get_from_keys(config, ["forward_dtype"]) + forward_method = cls.get_from_keys(config, ["forward_method"]) + pseudoquantization = cls.get_from_keys(config, ["pseudoquantization"]) + modules_to_not_convert = cls.get_from_keys(config, ["modules_to_not_convert"]) + return cls( + hadamard_group_size, + forward_dtype, + forward_method, + pseudoquantization, + modules_to_not_convert, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> LinearMethodBase | None: + if self.modules_to_not_convert is not None and any( + prefix.endswith(module) for module in self.modules_to_not_convert + ): + return UnquantizedLinearMethod() + + if isinstance(layer, LinearBase): + return FPQuantLinearMethod(self) + return None + + +class FPQuantLinearMethod(LinearMethodBase): + """Linear method for FPQuant. + + Args: + quant_config: The FPQuant quantization config. + """ + + def __init__(self, quant_config: FPQuantConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.bfloat16: + raise ValueError("Only bfloat16 is currently supported by FPQuant") + if input_size_per_partition % self.quant_config.hadamard_group_size != 0: # noqa: E501 + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by a tensor " + "parallel size that is too large."
+ ) + + assert self.quant_config.forward_dtype in ["mxfp4", "nvfp4"], ( + "Only mxfp4 and nvfp4 are supported for now" + ) + if self.quant_config.forward_dtype == "mxfp4": + group_size = 32 + elif self.quant_config.forward_dtype == "nvfp4": + group_size = 16 + else: + raise ValueError( + f"Unsupported forward_dtype: {self.quant_config.forward_dtype}" + ) + + qweight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": 2, + } + | extra_weight_attrs, + ) + layer.register_parameter("qweight", qweight) + + scales = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": group_size, + } + | extra_weight_attrs, + ) + layer.register_parameter("scales", scales) + + weight_global_scale = Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs( + weight_global_scale, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("weight_global_scale", weight_global_scale) + + act_global_scale = Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs( + act_global_scale, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("act_global_scale", act_global_scale) + + forward_hadamard_matrix = Parameter( + torch.empty( + self.quant_config.hadamard_group_size, + self.quant_config.hadamard_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + forward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("forward_hadamard_matrix", forward_hadamard_matrix) + + backward_hadamard_matrix = Parameter( + torch.empty( + self.quant_config.hadamard_group_size, + self.quant_config.hadamard_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + backward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs + ) + layer.register_parameter("backward_hadamard_matrix", backward_hadamard_matrix) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return quantized_forward( + x, + layer.qweight, + layer.scales, + layer.weight_global_scale, + layer.act_global_scale, + bias, + layer.forward_hadamard_matrix, + self.quant_config.forward_method, + self.quant_config.forward_dtype, + ) + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def fused_quantize_mx( + x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str +) -> tuple[torch.Tensor, torch.Tensor]: + return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method) + + +def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method): + rows, cols = x_flat.size(0), x_flat.size(1) // 32 + padded_rows = ((rows + 128 - 1) // 128) * 128 + padded_cols = ((cols + 4 - 1) // 4) * 4 + + xh_e2m1 = torch.empty( + x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device + ) + xh_e8m0 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=x_flat.device + ) + + return xh_e2m1, xh_e8m0 + + +direct_register_custom_op( + op_name="fused_quantize_mx", + op_func=fused_quantize_mx, + mutates_args=[], + 
fake_impl=fused_quantize_mx_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def matmul_mxf4_bf16( + x: torch.Tensor, + w: torch.Tensor, + xs: torch.Tensor, + ws: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return matmul_mxf4_bf16_tn( + x, + w, + to_blocked(xs, backend="triton").view(torch.float8_e8m0fnu), + to_blocked(ws, backend="triton").view(torch.float8_e8m0fnu), + alpha, + ) + + +def matmul_mxf4_bf16_fake(x, w, xs, ws, alpha): + return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device) + + +direct_register_custom_op( + op_name="matmul_mxf4_bf16", + op_func=matmul_mxf4_bf16, + mutates_args=[], + fake_impl=matmul_mxf4_bf16_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def fused_quantize_nv( + x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, global_scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + return fusedQuantizeNv(x_flat, hadamard_matrix, global_scale) + + +def fused_quantize_nv_fake(x_flat, hadamard_matrix, global_scale): + rows, cols = x_flat.size(0), x_flat.size(1) // 16 + padded_rows = ((rows + 128 - 1) // 128) * 128 + padded_cols = ((cols + 4 - 1) // 4) * 4 + + xh_e2m1 = torch.empty( + x_flat.size(0), x_flat.size(1) // 2, dtype=torch.uint8, device=x_flat.device + ) + xh_e8m0 = torch.empty( + padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=x_flat.device + ) + + return xh_e2m1, xh_e8m0 + + +direct_register_custom_op( + op_name="fused_quantize_nv", + op_func=fused_quantize_nv, + mutates_args=[], + fake_impl=fused_quantize_nv_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def matmul_nvf4_bf16( + x: torch.Tensor, + w: torch.Tensor, + xs: torch.Tensor, + ws: torch.Tensor, + alpha: torch.Tensor, +) -> torch.Tensor: + return cutlass_scaled_fp4_mm( + x, + w, + to_blocked(xs, backend="triton") + .view(torch.float8_e4m3fn) + .view(-1, x.shape[1] // 8), # *2//16 + to_blocked(ws, backend="triton") + .view(torch.float8_e4m3fn) + .view(-1, x.shape[1] // 8), + alpha, + torch.bfloat16, + ) + + +def matmul_nvf4_bf16_fake(x, w, xs, ws, alpha): + return torch.empty(*x.shape[:-1], w.shape[0], dtype=torch.bfloat16, device=x.device) + + +direct_register_custom_op( + op_name="matmul_nvf4_bf16", + op_func=matmul_nvf4_bf16, + mutates_args=[], + fake_impl=matmul_nvf4_bf16_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def quantized_forward( + x: torch.Tensor, + qweight: torch.Tensor, + weight_scales: torch.Tensor, + weight_global_scale: torch.Tensor, + act_global_scale: torch.Tensor, + bias: torch.Tensor | None, + forward_hadamard_matrix: torch.Tensor, + forward_method: str, + forward_dtype: str, +) -> torch.Tensor: + x_flat = x.contiguous().flatten(end_dim=-2) + + if forward_dtype == "mxfp4": + x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_mx( + x_flat, forward_hadamard_matrix, forward_method + ) + y = torch.ops.vllm.matmul_mxf4_bf16( + x_flat_q, + qweight, + x_flat_scales, + weight_scales, + 1 / (weight_global_scale * act_global_scale), + ) + elif forward_dtype == "nvfp4": + x_flat_q, x_flat_scales = torch.ops.vllm.fused_quantize_nv( + x_flat, forward_hadamard_matrix, act_global_scale + ) + y = torch.ops.vllm.matmul_nvf4_bf16( + x_flat_q, + qweight, + x_flat_scales, + weight_scales, + 1 / (weight_global_scale * act_global_scale), + ) + else: + raise ValueError(f"Unsupported forward_dtype: {forward_dtype}") + + y = y.view(*x.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias + + return y diff --git a/vllm/model_executor/layers/quantization/gguf.py 
b/vllm/model_executor/layers/quantization/gguf.py index 8296bc2ea3b4..8a914c57a9f7 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional import gguf import torch @@ -27,7 +28,7 @@ ) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -35,7 +36,7 @@ class GGUFConfig(QuantizationConfig): """Config class for GGUF.""" - def __init__(self, unquantized_modules: Optional[list[str]] = None) -> None: + def __init__(self, unquantized_modules: list[str] | None = None) -> None: super().__init__() self.unquantized_modules = unquantized_modules or [] @@ -307,7 +308,7 @@ def _apply_gguf_embedding( qweight: torch.Tensor, qweight_type: int, hidden_size: int, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: if qweight_type in UNQUANTIZED_TYPES: return torch.embedding(qweight, x) @@ -330,7 +331,7 @@ def _apply_gguf_embedding_fake( qweight: torch.Tensor, qweight_type: int, hidden_size: int, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device) @@ -452,7 +453,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: shard_id = layer.qweight.shard_id @@ -558,7 +559,7 @@ def create_weights( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return None def apply( @@ -569,21 +570,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 8f36fc70c444..2ad28048cdce 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -4,13 +4,14 
@@ import enum from enum import Enum from fractions import Fraction -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( @@ -28,13 +29,16 @@ RowvLLMParameter, ) from vllm.transformers_utils.config import get_safetensors_params_metadata -from vllm.utils import is_list_of +from vllm.utils.collection_utils import is_list_of if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.model_executor.models.utils import WeightsMapper else: QuantizationMethods = str +logger = init_logger(__name__) + class GPTQConfig(QuantizationConfig): """Config class for GPTQ. @@ -48,9 +52,10 @@ def __init__( self, weight_bits: int, group_size: int, desc_act: bool, lm_head_quantized: bool, - dynamic: dict[str, dict[str, Union[int, bool]]], + dynamic: dict[str, dict[str, int | bool]], autoround_version: str = "", - modules_in_block_to_quantize: Optional[list[str]] = None, + modules_in_block_to_quantize: list[str] | None = None, + checkpoint_format: str = "", ) -> None: # GPTQModel uses the `dynamic` config property to allow per-module # quantization config so each module can be individually optimized. @@ -88,12 +93,24 @@ def __init__( "Currently, only 2/3/4/8-bit weight quantization is " f"supported for GPTQ, but got {self.weight_bits} bits." ) + # The 4-bit gptq_gemm kernel is currently buggy; this may be fixed in the future. + # For now, show a warning, since gptq_marlin will be used by default. + if self.weight_bits == 4: + logger.warning_once( + "Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. " + "Please switch to gptq_marlin or gptq_bitblas." + ) self.modules_in_block_to_quantize = modules_in_block_to_quantize or [] # used to identify GPTQ model quantized by autoround self.autoround_version = autoround_version + # The GPTQ v1 and v2 formats deal with zero points differently. + # Currently GPTQModel stores v1 format checkpoints by default, + # but provides the option to set `format="gptq_v2"` in `QuantizeConfig`.
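+ # `checkpoint_format == "gptq_v2"` selects the v2 kernel path in
+ # GPTQLinearMethod below; any other value (including the default "") is
+ # treated as the original v1 layout.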
+ self.checkpoint_format = checkpoint_format + def __repr__(self) -> str: return ( f"GPTQConfig(weight_bits={self.weight_bits}, " @@ -101,7 +118,8 @@ def __repr__(self) -> str: f"desc_act={self.desc_act}), " f"lm_head_quantized={self.lm_head_quantized}, " f"dynamic={self.dynamic}, " - f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})" + f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), " + f"checkpoint_format={self.checkpoint_format})" ) @classmethod @@ -136,6 +154,9 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQConfig": modules_in_block_to_quantize = cls.get_from_keys_or( config, ["modules_in_block_to_quantize"], default=None ) + checkpoint_format = cls.get_from_keys_or( + config, ["checkpoint_format"], default="" + ) return cls( weight_bits, group_size, @@ -144,15 +165,17 @@ dynamic, autoround_version, modules_in_block_to_quantize, + checkpoint_format, ) def get_quant_method( self, layer: torch.nn.Module, prefix: str - ) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]: + ) -> Union["GPTQLinearMethod", "QuantizeMethodBase"] | None: if isinstance(layer, FusedMoE): # GPTQ MoE support: fall back to MoeWNA16 for broad compatibility from .moe_wna16 import MoeWNA16Config + # TODO: maybe update this for GPTQv2 format checkpoints config = { "quant_method": "gptq", "bits": self.weight_bits, @@ -164,13 +187,13 @@ return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) - def apply_vllm_mapper(self, hf_to_vllm_mapper): + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.modules_in_block_to_quantize is not None: self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list( self.modules_in_block_to_quantize ) - def maybe_update_config(self, model_name: str, revision: Optional[str] = None): + def maybe_update_config(self, model_name: str, revision: str | None = None): if self.modules_in_block_to_quantize: if is_list_of(self.modules_in_block_to_quantize, list): # original modules_in_block_to_quantize: list[list[str]] @@ -209,6 +232,9 @@ class GPTQLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQConfig): self.quant_config = quant_config + # The GPTQ v1 and v2 formats deal with zero points differently + self.use_v2_format = quant_config.checkpoint_format == "gptq_v2" + def create_weights( self, layer: torch.nn.Module, @@ -345,11 +371,13 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: out_shape = x.shape[:-1] + (layer.qweight.shape[-1],) reshaped_x = x.reshape(-1, x.shape[-1]) + # GPTQ v1 and v2 format checkpoints deal with zero points differently, + # and require different gemm kernels.
output = ops.gptq_gemm( reshaped_x, layer.qweight, @@ -357,6 +385,7 @@ def apply( layer.scales, layer.g_idx, layer.exllama_state == ExllamaState.READY, + self.use_v2_format, self.quant_config.weight_bits, ) if bias is not None: diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index 85cf4ed4ac58..92f10bfd5c02 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -71,7 +71,7 @@ def __init__( group_size: int, desc_act: bool, is_sym: bool, - quant_method: Optional[str], + quant_method: str | None, lm_head_quantized: bool, ) -> None: try: @@ -180,7 +180,7 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) is_valid_user_quant = ( @@ -474,7 +474,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: out = self.kernel.apply_gptq_bitblas_linear(layer, x) if bias is not None: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 8fa70a240f9f..0d5439357fda 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable, Optional, Union +from typing import Any, Optional import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE @@ -14,6 +15,7 @@ FusedMoEConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, @@ -55,7 +57,7 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.transformers_utils.config import get_safetensors_params_metadata -from vllm.utils import is_list_of +from vllm.utils.collection_utils import is_list_of logger = init_logger(__name__) @@ -103,9 +105,9 @@ def __init__( desc_act: bool, is_sym: bool, lm_head_quantized: bool, - dynamic: dict[str, dict[str, Union[int, bool]]], + dynamic: dict[str, dict[str, int | bool]], full_config: dict[str, Any], - modules_in_block_to_quantize: Optional[list[str]] = None, + modules_in_block_to_quantize: list[str] | None = None, ) -> None: super().__init__() if desc_act and group_size == -1: @@ -211,7 +213,7 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) is_valid_user_quant = ( @@ -283,7 +285,7 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper): self.modules_in_block_to_quantize ) - def maybe_update_config(self, model_name: str, revision: Optional[str] = None): + def maybe_update_config(self, model_name: str, revision: str | None = None): if self.modules_in_block_to_quantize: if is_list_of(self.modules_in_block_to_quantize, list): # original modules_in_block_to_quantize: 
list[list[str]] @@ -459,7 +461,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) @@ -714,7 +716,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return None def apply( @@ -725,21 +727,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: @@ -764,7 +766,7 @@ def apply( indices_type=self.topk_indices_dtype, ) - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_qweight, layer.w2_qweight, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 8f0df55b0a5c..2fb614b4746e 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -114,7 +114,7 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24" is_valid_user_quant = ( @@ -287,7 +287,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: qweight = layer.B_24 meta = layer.B_meta diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index e61caf6b459b..5fb67c35378b 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -45,7 +45,7 @@ def __init__( self, weight_bits: int, group_size: int, - skip_modules: Optional[list[str]] = None, + skip_modules: list[str] | None = None, ) -> None: super().__init__() assert group_size == 64, "The only supported HQQ group size is currently 64." 
@@ -327,7 +327,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: workspace = MarlinWorkspace( self.output_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 8786638869a4..7ded8eea7906 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch import torch.nn.functional as F @@ -30,9 +29,9 @@ def __init__( self, static: bool, group_shape: GroupShape, - num_token_padding: Optional[int] = None, + num_token_padding: int | None = None, column_major_scales: bool = False, - use_ue8m0: Optional[bool] = None, # for Torch compile + use_ue8m0: bool | None = None, # for Torch compile ): """ :param static: static or dynamic quantization @@ -64,8 +63,8 @@ def __init__( def forward_cuda( self, x: torch.Tensor, - scale: Optional[torch.Tensor] = None, - scale_ub: Optional[torch.Tensor] = None, + scale: torch.Tensor | None = None, + scale_ub: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" @@ -96,8 +95,8 @@ def forward_cuda( def forward_native( self, x: torch.Tensor, - scale: Optional[torch.Tensor] = None, - scale_ub: Optional[torch.Tensor] = None, + scale: torch.Tensor | None = None, + scale_ub: torch.Tensor | None = None, ): if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 4aa0e464e0f5..5b3aabfde0c1 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import torch from packaging import version @@ -23,12 +24,10 @@ QuantizationConfig, QuantizationMethods, ) -from vllm.model_executor.layers.quantization.awq import ( - AWQLinearMethod, - is_layer_skipped_awq, -) +from vllm.model_executor.layers.quantization.awq import AWQLinearMethod from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -50,9 +49,9 @@ def __init__( method: str, weight_bits: int, group_size: int, - modules_to_not_convert: Optional[list[str]] = None, - desc_act: Optional[bool] = None, - lm_head_quantized: Optional[bool] = None, + modules_to_not_convert: list[str] | None = None, + desc_act: bool | None = None, + lm_head_quantized: bool | None = None, ) -> None: super().__init__() self.method = method @@ -122,7 +121,7 @@ def from_config(cls, config: dict[str, Any]) -> "IPEXConfig": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: if not 
current_platform.is_cpu() and not current_platform.is_xpu(): return None @@ -138,7 +137,9 @@ def get_quant_method( ) -> Optional["LinearMethodBase"]: if isinstance(layer, LinearBase): if self.method == "awq": - if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + if is_layer_skipped( + prefix, self.modules_to_not_convert, self.packed_modules_mapping + ): return UnquantizedLinearMethod() return IPEXAWQLinearMethod(self) if self.method == "gptq": @@ -206,7 +207,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) @@ -275,7 +276,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) @@ -299,7 +300,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: weight = layer.weight.data weight_scale = layer.weight_scale.data @@ -410,7 +411,7 @@ def process_weights_after_loading(self, layer: Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return None def apply( @@ -421,20 +422,20 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor: return layer.ipex_fusion( x, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index 055a3ebbced6..7aeb1f86c279 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable, Optional import torch @@ -20,7 +20,7 @@ class MPLinearLayerConfig: group_size: int zero_points: bool has_g_idx: bool - out_type: Optional[torch.dtype] = None + out_type: torch.dtype | None = None class MPLinearKernel(ABC): @@ -31,7 +31,7 @@ def get_min_capability(cls) -> int: @classmethod @abstractmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: 
+ def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: raise NotImplementedError def __init__( @@ -39,8 +39,8 @@ def __init__( c: MPLinearLayerConfig, w_q_param_name: str, w_s_param_name: str, - w_zp_param_name: Optional[str] = None, - w_gidx_param_name: Optional[str] = None, + w_zp_param_name: str | None = None, + w_gidx_param_name: str | None = None, ) -> None: assert self.can_implement(c) self.config = c @@ -62,12 +62,12 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError def _transform_param( - self, layer: torch.nn.Module, name: Optional[str], fn: Callable + self, layer: torch.nn.Module, name: str | None, fn: Callable ) -> None: if name is not None and getattr(layer, name, None) is not None: old_param = getattr(layer, name) @@ -83,8 +83,8 @@ def _get_weight_params( ) -> tuple[ torch.Tensor, # w_q torch.Tensor, # w_s - Optional[torch.Tensor], # w_zp, - Optional[torch.Tensor], # w_gidx + torch.Tensor | None, # w_zp, + torch.Tensor | None, # w_gidx ]: return ( getattr(layer, self.w_q_name), diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 1759d142e6cc..0cf3f12af552 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - import vllm.envs as envs from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 AllSparkLinearKernel, @@ -48,7 +46,7 @@ def choose_mp_linear_kernel( - config: MPLinearLayerConfig, compute_capability: Optional[int] = None + config: MPLinearLayerConfig, compute_capability: int | None = None ) -> type[MPLinearKernel]: """ Choose an MPLinearKernel that can implement the given config for the given diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py index c353372b05ec..3baef454251a 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -22,7 +21,7 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: if c.has_g_idx: return False, "Act reordering currently not supported by AllSpark" @@ -87,7 +86,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: c = self.config gemm_args = self.gemm_args diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py index d1ff582c4e21..59c6a4f96154 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +++ 
b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch from packaging import version @@ -44,9 +43,9 @@ def __init__( c: MPLinearLayerConfig, w_q_param_name: str, w_s_param_name: str, - w_zp_param_name: Optional[str] = None, - w_gidx_param_name: Optional[str] = None, - bitblas_quant_config: Optional[QuantizationConfig] = None, + w_zp_param_name: str | None = None, + w_gidx_param_name: str | None = None, + bitblas_quant_config: QuantizationConfig | None = None, ): self.quant_config = bitblas_quant_config super().__init__( @@ -57,7 +56,7 @@ def repack_bitblas_from_gptq( self, b_q_weight: torch.Tensor, scales: torch.Tensor, - qzeros: Optional[torch.Tensor] = None, + qzeros: torch.Tensor | None = None, ): from bitblas.quantization.utils import general_compress @@ -82,7 +81,7 @@ def repack_bitblas_from_gptq( # qzeros should be de-quantized to int zeros. weight_bits = quant_config.weight_bits # type: ignore[union-attr] intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous() - zeros: Optional[torch.Tensor] = None + zeros: torch.Tensor | None = None zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined] if zeros_mode == "original": zeros = intzeros.to(torch.float16).contiguous() @@ -113,7 +112,7 @@ def get_min_capability(cls) -> int: return 70 @classmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: is_bitblas_installed = True try: diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py index 281fca7888ab..53b2e15df76d 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from importlib.util import find_spec -from typing import Final, Optional +from typing import Final import torch @@ -26,7 +26,7 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES: error_msg = ( f"Weight type ({c.weight_type}) not supported by " @@ -76,7 +76,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: from conch.ops.quantization.gemm import mixed_precision_gemm diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py index f5df7a244b42..8ef6457c952f 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -26,7 +25,7 @@ def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: 
MPLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cuda(): return False, "CUTLASS only supported on CUDA" @@ -95,7 +94,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: c = self.config w_q, w_s, _, _ = self._get_weight_params(layer) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py index 7631236e6f64..d09bd86a7274 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -20,7 +19,7 @@ def get_min_capability(cls) -> int: return 1 @classmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cpu(): return False, "Only CPU is supported" if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: @@ -95,7 +94,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: c = self.config x_2d = x.reshape(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py index a57d3f65267e..9fba4aafb05a 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -25,7 +24,7 @@ def get_min_capability(cls) -> int: return 60 @classmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]: return ( False, @@ -137,7 +136,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: c = self.config @@ -146,10 +145,15 @@ def apply_weights( w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) + # gptq_gemm supports GPTQv2 format by passing use_v2_format=True. + # However, the MPLinearLayerConfig doesn't contain format info. + # So hardcode GPTQv1 format here, to keep its behavior unchanged. 
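+ # (If GPTQv2 checkpoints ever need to reach this kernel, the checkpoint
+ # format would have to be plumbed through MPLinearLayerConfig first.)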
+ use_v2_format = False + assert w_zp is not None, "Zero points are required by Exllama" assert w_g_idx is not None, "Group index is required by Exllama" output = ops.gptq_gemm( - x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits + x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits ) if bias is not None: diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index df2f8fedce7e..7953ed5e8ee4 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from functools import partial -from typing import Optional import torch @@ -28,7 +27,7 @@ def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: # Machete uses CUTLASS, so it can only be compatible with Nvidia if not current_platform.is_cuda(): return False, "Machete only supported on CUDA" @@ -129,7 +128,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: c = self.config w_q, w_s, w_zp, _ = self._get_weight_params(layer) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index 0be448e4e3d8..ac21286eeffa 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -32,7 +31,7 @@ def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: # Marlin uses inline PTX, so it can only be compatible with Nvidia if not current_platform.is_cuda(): return False, "Marlin only supported on CUDA" @@ -144,7 +143,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: c = self.config w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index d9b999e3d5dd..2a885ec89945 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional import torch @@ -23,7 +22,7 @@ def get_min_capability(cls) -> int: @classmethod @abstractmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: raise NotImplementedError def __init__( @@ -52,7 +51,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, 
- bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError @@ -61,9 +60,9 @@ def _get_weight_params( ) -> tuple[ torch.Tensor, # weight torch.Tensor, # weight_scale - Optional[torch.Tensor], # input_scale, - Optional[torch.Tensor], # input_zp - Optional[torch.Tensor], # azp_adj + torch.Tensor | None, # input_scale, + torch.Tensor | None, # input_zp + torch.Tensor | None, # azp_adj ]: return ( getattr(layer, self.w_q_name), diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index ee5416bae01c..dd59e5d935dc 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( AiterScaledMMLinearKernel, @@ -35,7 +34,7 @@ def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, compute_capability: Optional[int] = None + config: ScaledMMLinearLayerConfig, compute_capability: int | None = None ) -> type[ScaledMMLinearKernel]: """ Choose an ScaledMMLinearKernel that can implement the given config for the diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index e97beefdd9c2..a19396a162bc 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -1,14 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig @@ -19,7 +18,7 @@ def rocm_aiter_gemm_w8a8_impl( B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: from aiter import gemm_a8w8_CK @@ -36,7 +35,7 @@ def rocm_aiter_gemm_w8a8_fake( B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: m = A.shape[0] @@ -59,7 +58,7 @@ def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_rocm(): return ( False, @@ -99,7 +98,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: """ `AiterScaledMMLinearKernel` implements a fused version of diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index cb00b0c8af21..feb1e0bee1aa 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ 
b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -24,7 +23,7 @@ def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cpu(): return False, "CPUScaledMM requires running on CPU." @@ -173,7 +172,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.linear_method( layer, @@ -185,7 +184,7 @@ def _apply_weights_onednn( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) @@ -207,7 +206,7 @@ def _apply_weights_sgl( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: w_q, w_s, _, _, _ = self._get_weight_params(layer) return torch.ops._C.int8_scaled_mm_with_quant( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index f1dafdf14c7a..e8769916b4ce 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -21,7 +20,7 @@ def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_cuda(): return False, "CutlassScaledMM requires running on CUDA." @@ -89,8 +88,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # azp_adj is the AZP adjustment term, used to account for weights. # It does not depend on scales or azp, so it is the same for # static and dynamic quantization. 
- # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md - # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md + # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md + # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md if not self.config.input_symmetric: weight = getattr(layer, self.w_q_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) @@ -110,7 +109,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py index 7e21afca5750..3f4ec7f2a738 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -17,7 +16,7 @@ def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if current_platform.is_cpu(): return ( False, @@ -38,6 +37,6 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return super().apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 63eee1e28861..ddac9f13cf4f 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import warnings -from typing import Optional import torch from functorch.experimental.control_flow import cond # noqa: F401 @@ -25,7 +24,7 @@ def get_min_capability(cls) -> int: ) @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: if not current_platform.is_tpu(): return False, "ScaledMMXLA requires running on TPU." @@ -77,17 +76,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: message="Pred is a Python constant. 
When used with torch.cond, it specializes on one of the branches.", # noqa: E501 ) - def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None): return x - def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None): return x + bias def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: w_q, w_s, _, _, _ = self._get_weight_params(layer) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index c285b10720d8..0eeeaa3ce457 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional import torch from torch.nn import Module @@ -20,6 +21,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe, ) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, @@ -47,6 +49,7 @@ build_flashinfer_fp8_cutlass_moe_prepare_finalize, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, + is_flashinfer_supporting_global_sf, register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, @@ -70,7 +73,6 @@ ) from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.scalar_type import scalar_types -from vllm.utils import next_power_of_2 from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, has_flashinfer, @@ -92,8 +94,8 @@ class ModelOptFp8Config(QuantizationConfig): def __init__( self, is_checkpoint_fp8_serialized: bool = False, - kv_cache_quant_method: Optional[str] = None, - exclude_modules: Optional[list[str]] = None, + kv_cache_quant_method: str | None = None, + exclude_modules: list[str] | None = None, ) -> None: super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized @@ -128,7 +130,7 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: """Detect if this ModelOpt config should be used based on quantization config.""" @@ -319,7 +321,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.fp8_linear.apply( input=x, @@ -351,8 +353,12 @@ def __init__( ) self.cutlass_fp8_supported = cutlass_fp8_supported() - self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None - if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + self.flashinfer_moe_backend: FlashinferMoeBackend | None = None + if ( + envs.VLLM_USE_FLASHINFER_MOE_FP8 + and has_flashinfer_moe() + and self.moe.is_act_and_mul + ): self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" @@ -360,7 +366,7 @@ def __init__( def maybe_make_prepare_finalize( 
self, - ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + ) -> mk.FusedMoEPrepareAndFinalize | None: # TRT LLM not supported with all2all yet. if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None @@ -403,10 +409,15 @@ def create_weights( ) weight_loader = extra_weight_attrs.get("weight_loader") + if self.moe.is_act_and_mul: + w13_up_dim = 2 * intermediate_size_per_partition + else: + w13_up_dim = intermediate_size_per_partition + w13_weight = ModelWeightParameter( data=torch.empty( num_experts, - 2 * intermediate_size_per_partition, + w13_up_dim, hidden_size, dtype=weight_dtype, ), @@ -431,11 +442,16 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALES - Per-tensor scaling for ModelOpts - # Allocate 2 scales for w1 and w3 respectively. + # For gated MoE, allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. + # For non-gated MoE, allocate 1 scale for w13. + if self.moe.is_act_and_mul: + w13_weight_scale_shape = (num_experts, 2) + else: + w13_weight_scale_shape = (num_experts, 1) w13_weight_scale = PerTensorScaleParameter( data=torch.full( - (num_experts, 2), + w13_weight_scale_shape, 1.0, dtype=torch.float32, ), @@ -483,7 +499,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max of the w1 and w3 scales # then dequant and requant each expert. - if layer.w13_weight_scale.dim() == 2: + if ( + layer.w13_weight_scale.dim() == 2 + and layer.w13_weight_scale.shape[1] == 2 + ): + assert self.moe.is_act_and_mul, ( + "w13_weight_scale should have 2 elements per expert " + "only for gated MoE" + ) # Get the maximum scale across w1 and w3 for each expert max_w13_scales = layer.w13_weight_scale.max(dim=1).values @@ -541,7 +564,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None @@ -561,21 +584,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet." 
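The create_weights/process_weights_after_loading changes above distinguish gated MoE (is_act_and_mul, two per-tensor scales per expert for w1 and w3) from non-gated MoE (a single w13 scale). As the in-line comments note, the two scales are later folded into one by taking the per-expert max and requantizing each half. A minimal sketch of that folding, with a hypothetical helper name, treating the quantized weights as plain floats (fp8 rounding and clamping omitted):

import torch

def fold_w13_scales(
    w13_q: torch.Tensor, w13_scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # w13_q: (E, 2*I, H) quantized w1/w3 weights, kept as float here for illustration.
    # w13_scale: (E, 2) per-tensor scales for the w1 and w3 halves of each expert.
    num_experts, up_dim, _ = w13_q.shape
    half = up_dim // 2
    # One shared scale per expert: the max of the w1 and w3 scales.
    max_w13_scales = w13_scale.max(dim=1).values
    for e in range(num_experts):
        for i in range(2):
            sub = w13_q[e, i * half : (i + 1) * half]
            dq = sub * w13_scale[e, i]  # dequantize with the original scale
            # requantize with the shared max scale
            w13_q[e, i * half : (i + 1) * half] = dq / max_w13_scales[e]
    return w13_q, max_w13_scales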
@@ -674,7 +697,7 @@ class ModelOptNvFp4Config(QuantizationConfig): def __init__( self, is_checkpoint_nvfp4_serialized: bool, - kv_cache_quant_algo: Optional[str], + kv_cache_quant_algo: str | None, exclude_modules: list[str], group_size: int = 16, ) -> None: @@ -713,7 +736,7 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: """Detect if this ModelOpt FP4 config should be used based on quantization config.""" if hf_quant_cfg is None: @@ -906,7 +929,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): Supports loading kv-cache scaling factors from FP8 checkpoints. """ - def __init__(self, quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config]): + def __init__(self, quant_config: ModelOptFp8Config | ModelOptNvFp4Config): super().__init__(quant_config) @@ -924,22 +947,26 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config - if envs.VLLM_USE_TRTLLM_FP4_GEMM: - assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" - self.backend = "flashinfer-trtllm" - elif has_flashinfer(): - self.backend = "flashinfer-cutlass" - elif cutlass_fp4_supported(): - self.backend = "cutlass" - elif is_fp4_marlin_supported(): - self.backend = "marlin" - else: + self.backend = "none" + if envs.VLLM_NVFP4_GEMM_BACKEND is None: + if has_flashinfer(): + self.backend = "flashinfer-cutlass" + elif cutlass_fp4_supported(): + self.backend = "cutlass" + elif is_fp4_marlin_supported(): + self.backend = "marlin" + elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"): + self.backend = envs.VLLM_NVFP4_GEMM_BACKEND + assert has_flashinfer(), f"FlashInfer is required for {self.backend}" + + if self.backend == "none": raise ValueError( - "Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above." + "No valid NVFP4 GEMM backend found. " + "Please check your platform capability." ) + logger.info_once(f"Using {self.backend} for NVFP4 GEMM") + def create_weights( self, layer: torch.nn.Module, @@ -1071,7 +1098,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.backend == "marlin": return apply_fp4_marlin_linear( @@ -1107,11 +1134,11 @@ def apply( layer.alpha, output_dtype, ) - if self.backend == "flashinfer-trtllm": - out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") - elif self.backend == "flashinfer-cutlass": - out = flashinfer_scaled_fp4_mm(*mm_args, backend="cutlass") + if self.backend.startswith("flashinfer-"): + backend_name = self.backend[len("flashinfer-") :] + out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) else: + assert self.backend == "cutlass" out = cutlass_scaled_fp4_mm(*mm_args) if bias is not None: @@ -1119,16 +1146,6 @@ def apply( return out.view(*output_shape) -def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int: - # Guess tokens per expert assuming perfect expert distribution first. - num_tokens_per_expert = (num_tokens * top_k) // num_experts - # And pad the number to the next power of 2. - tile_tokens_dim = next_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. 
- tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - return tile_tokens_dim - - class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ MoE Method for FP4 Quantization. @@ -1162,7 +1179,7 @@ def __init__( " for ModelOptNvFp4FusedMoE." ) - def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None: if self.use_marlin or ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM @@ -1222,6 +1239,7 @@ def create_weights( weight_dtype = torch.uint8 weight_scale_dtype = torch.float8_e4m3fn weight_loader = extra_weight_attrs.get("weight_loader") + global_num_experts = extra_weight_attrs.get("global_num_experts") # GEMM 1 w13_weight = ModelWeightParameter( data=torch.empty( @@ -1300,14 +1318,19 @@ def create_weights( {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) + use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf( + self.flashinfer_moe_backend + ) + global_scale_num_experts = global_num_experts if use_global_sf else num_experts + w13_input_scale = PerTensorScaleParameter( - data=torch.empty(num_experts, 2, dtype=torch.float32), + data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) w2_input_scale = PerTensorScaleParameter( - data=torch.empty(num_experts, dtype=torch.float32), + data=torch.empty(global_scale_num_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_input_scale", w2_input_scale) @@ -1326,8 +1349,8 @@ def prepare_static_weights_for_trtllm_fp4_moe( ): from flashinfer import nvfp4_block_scale_interleave from flashinfer.fused_moe.core import ( - _maybe_get_cached_w2_permute_indices, _maybe_get_cached_w3_w1_permute_indices, + get_w2_permute_indices_with_cache, ) """Prepare quantized weights for kernel (done offline with weights).""" @@ -1388,7 +1411,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( ) ) - permute_indices = _maybe_get_cached_w2_permute_indices( + permute_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m, @@ -1399,7 +1422,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( .contiguous() ) - permute_sf_indices = _maybe_get_cached_w2_permute_indices( + permute_sf_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m, @@ -1462,7 +1485,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) # Common processing for input scales and alphas - w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) + use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf( + self.flashinfer_moe_backend + ) + if use_global_sf: + # For backends provide by Flashinfer, the input global scales are + # shared across all experts. 
+ w13_input_scale = ( + layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) + ) + else: + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), requires_grad=False, @@ -1474,14 +1507,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # GEMM 2 processing + if use_global_sf: + # For backends provide by Flashinfer, the input global scales are + # shared across all experts. + w2_input_scale = ( + layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) + ) + else: + w2_input_scale = layer.w2_input_scale layer.g2_alphas = Parameter( - (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), requires_grad=False, ) # This is for quantization, so we need to invert it. layer.w2_input_scale_quant = Parameter( - (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False + (1 / w2_input_scale).to(torch.float32), requires_grad=False ) # TensorRT-LLM specific processing @@ -1540,23 +1581,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: del layer.w2_input_scale_quant else: # Non-TRT-LLM processing (Cutlass or non-flashinfer) - assert layer.w13_weight_scale.shape[2] % 16 == 0, ( - "Expected weight_scale.dim(1) to be divisible by 16" - ) - assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, ( - "Weight Blockscale must be represented as FP8-E4M3" - ) w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) layer.w13_weight_scale = Parameter( w13_blockscale_swizzled, requires_grad=False ) - assert layer.w2_weight_scale.shape[2] % 16 == 0, ( - "Expected weight_scale.dim(1) to be divisible by 16" - ) - assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, ( - "Weight Blockscale must be represented as FP8-E4M3" - ) w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) layer.w2_weight_scale = Parameter( w2_blockscale_swizzled, requires_grad=False @@ -1565,7 +1594,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: if ( self.use_marlin or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM @@ -1589,21 +1618,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | 
tuple[torch.Tensor, torch.Tensor]: if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptNvFp4FusedMoE` yet." @@ -1670,9 +1699,7 @@ def apply( local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, routed_scaling_factor=None, - tile_tokens_dim=_get_tile_tokens_dim( - x.shape[0], top_k, layer.local_num_experts - ), + tile_tokens_dim=None, routing_method_type=routing_method_type, do_finalize=True, )[0] @@ -1700,7 +1727,7 @@ def apply( # if self.use_marlin: assert self.fused_experts is None - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 3719672f6e52..b0a268b9950b 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional import torch @@ -40,7 +41,7 @@ def __init__( group_size: int, has_zp: bool, lm_head_quantized: bool, - modules_to_not_convert: Optional[list[str]], + modules_to_not_convert: list[str] | None, full_config: dict[str, Any], ) -> None: super().__init__() @@ -127,7 +128,7 @@ def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) if can_convert and user_quant == "moe_wna16": return cls.get_name() @@ -339,7 +340,7 @@ def create_weights( def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp assert weight_bits == 4 or weight_bits == 8 @@ -365,21 +366,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.") diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 
dd9532be7585..6823fa02a32d 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable from enum import Enum -from typing import Callable, Optional, Union +from typing import Optional import torch from torch.nn.parameter import Parameter @@ -17,10 +18,15 @@ from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, + mxfp4_mxfp8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + BatchedMarlinExperts, + MarlinExperts, + fused_marlin_moe, +) from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts, ) @@ -42,13 +48,10 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.utils import ( - has_triton_kernels, - is_torch_equal_or_newer, - next_power_of_2, - round_up, -) +from vllm.utils import round_up from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.import_utils import has_triton_kernels +from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) @@ -92,12 +95,6 @@ def get_mxfp4_backend(): and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 ): - logger.info_once( - "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, " - "for high concurrency throughput workloads consider setting " - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better " - "performance" - ) return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM elif current_platform.is_device_capability(100) and has_flashinfer(): logger.info_once( @@ -137,7 +134,7 @@ def get_mxfp4_backend(): class Mxfp4Config(QuantizationConfig): - def __init__(self, ignored_layers: Optional[list[str]] = None): + def __init__(self, ignored_layers: list[str] | None = None): super().__init__() self.ignored_layers = ignored_layers @@ -188,7 +185,7 @@ def __init__(self, moe: FusedMoEConfig): self.moe = moe self.mxfp4_backend = get_mxfp4_backend() self.max_capture_size = ( - get_current_vllm_config().compilation_config.max_capture_size + get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) assert self.mxfp4_backend != Mxfp4Backend.NONE, ( @@ -352,7 +349,7 @@ def process_weights_after_loading(self, layer): or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 ): from flashinfer.fp4_quantization import nvfp4_block_scale_interleave - from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices + from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache layer.gemm1_alpha = Parameter( torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), @@ -444,7 +441,7 @@ def swap_every_two_rows(x, axis=-1): epilogue_tile_m = 128 # FIXME: this depends on the kernel internals for i in range(self.num_experts): # w13 weight shuffling - permute_indices = _maybe_get_cached_w2_permute_indices( + permute_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w13_weight[i].view(torch.uint8), epilogue_tile_m, @@ -455,7 +452,7 @@ def swap_every_two_rows(x, axis=-1): .contiguous() ) # w13 scale shuffling - permute_sf_indices = 
_maybe_get_cached_w2_permute_indices( + permute_sf_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w13_weight_scale[i].view(torch.uint8), epilogue_tile_m, @@ -471,7 +468,7 @@ def swap_every_two_rows(x, axis=-1): ) ) # w13 bias shuffling - permute_bias_indices = _maybe_get_cached_w2_permute_indices( + permute_bias_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m, @@ -483,7 +480,7 @@ def swap_every_two_rows(x, axis=-1): .contiguous() ) # w2 weight shuffling - permute_indices = _maybe_get_cached_w2_permute_indices( + permute_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w2_weight[i].view(torch.uint8), epilogue_tile_m, @@ -494,7 +491,7 @@ def swap_every_two_rows(x, axis=-1): .contiguous() ) # w2 scale shuffling - permute_sf_indices = _maybe_get_cached_w2_permute_indices( + permute_sf_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w2_weight_scale[i].view(torch.uint8), epilogue_tile_m, @@ -510,7 +507,7 @@ def swap_every_two_rows(x, axis=-1): ) ) # w2 bias shuffling - permute_indices = _maybe_get_cached_w2_permute_indices( + permute_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m, @@ -730,33 +727,9 @@ def _interleave_mxfp4_cutlass_sm90(w): else: raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") - def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): - # Number of tokens in the input tensor. - num_tokens = x.shape[0] - # Factor to account for the imbalance of the experts. - # factor equals to the - # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert - # - 1.0 means perfect expert distribution. - # - > 1.0 means some experts have more - # tokens than the perfect distribution. - # - < 1.0 does not make sense. - imbalance_factor = 1.3 - # Calculate the number of tokens per expert - # assuming perfect distribution. - num_tokens_per_expert = (num_tokens * top_k) // self.num_experts - # Apply the imbalance factor. - num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) - # And pad the number to the next power of 2. - tile_tokens_dim = next_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile - # as it's the range supported by the kernel. 
- tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - - return tile_tokens_dim - def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: if self.mxfp4_backend == Mxfp4Backend.MARLIN: return mxfp4_w4a16_moe_quant_config( w1_bias=layer.w13_bias, @@ -773,6 +746,23 @@ def get_fused_moe_quant_config( w1_scale=w1_scale, w2_scale=w2_scale, ) + elif self.mxfp4_backend in [ + Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, + Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS, + ]: + return mxfp4_mxfp8_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) + elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]: + return mxfp4_w4a16_moe_quant_config( + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) else: w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale @@ -793,9 +783,20 @@ def select_gemm_impl( prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts ): - raise NotImplementedError( - "Mxfp4 does not support batched experts format for EP" - ) + if self.mxfp4_backend == Mxfp4Backend.MARLIN: + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens_per_rank is not None + assert self.moe_quant_config is not None + return BatchedMarlinExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + ) + else: + raise NotImplementedError( + f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for " + "EP batched experts format" + ) else: assert self.moe_quant_config is not None if ( @@ -813,8 +814,12 @@ def select_gemm_impl( return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) elif self.mxfp4_backend == Mxfp4Backend.MARLIN: return MarlinExperts(self.moe_quant_config) - else: + elif self.mxfp4_backend == Mxfp4Backend.TRITON: return OAITritonExperts(self.moe_quant_config) + else: + raise NotImplementedError( + f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" + ) def _route_and_experts( self, @@ -824,19 +829,19 @@ def _route_and_experts( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor: assert isinstance(self.fused_experts, mk.FusedMoEModularKernel) @@ -890,21 +895,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, 
+ num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") @@ -946,7 +951,7 @@ def apply( e_score_correction_bias=e_score_correction_bias, ) - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, @@ -1022,7 +1027,7 @@ def apply( layer.ep_rank * layer.local_num_experts, # local_expert_offset self.num_experts, # local num experts None, - self._get_tile_tokens_dim(x, top_k), + None, 1 if renormalize else 0, # routing_method_type, renormalize True, # do finalize tune_max_num_tokens=self.max_capture_size, diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py index 60519bdaea02..402cebc38c21 100644 --- a/vllm/model_executor/layers/quantization/petit.py +++ b/vllm/model_executor/layers/quantization/petit.py @@ -41,9 +41,9 @@ class PetitNvFp4Config(QuantizationConfig): def __init__( self, is_checkpoint_nvfp4_serialized: bool = False, - kv_cache_quant_algo: Optional[str] = None, - group_size: Optional[int] = None, - exclude_modules: Optional[list[str]] = None, + kv_cache_quant_algo: str | None = None, + group_size: int | None = None, + exclude_modules: list[str] | None = None, ) -> None: self._check_hardware_support() self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized @@ -133,7 +133,7 @@ def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant - ) -> Optional[QuantizationMethods]: + ) -> QuantizationMethods | None: if not current_platform.is_rocm(): return None @@ -307,7 +307,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return apply_petit_nvfp4_linear( input=x, diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index c0156321f65d..26ba8e5b16bc 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -34,7 +34,7 @@ class PTPCFp8Config(Fp8Config): def __init__( self, activation_scheme: str = "dynamic", - ignored_layers: Optional[list[str]] = None, + ignored_layers: list[str] | None = None, ) -> None: if not current_platform.is_rocm(): raise ValueError("ptpc_fp8 quantization is supported only on ROCm.") @@ -125,7 +125,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> 
torch.Tensor: return self.fp8_linear.apply( input=x, diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 51f9d56121bd..d5459594b798 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -43,8 +43,8 @@ class QuarkConfig(QuantizationConfig): def __init__( self, quant_config: dict[str, Any], - kv_cache_group: Optional[list[str]] = None, - kv_cache_config: Optional[dict[str, Any]] = None, + kv_cache_group: list[str] | None = None, + kv_cache_config: dict[str, Any] | None = None, pack_method: str = "reorder", ): super().__init__() @@ -178,8 +178,8 @@ def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bo def _is_fp8_w8a8( self, - weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]], + weight_quant: dict[str, Any] | None, + input_quant: dict[str, Any] | None, ) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: @@ -209,8 +209,8 @@ def _is_fp8_w8a8( def _is_static_tensor_w8a8( self, - weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]], + weight_quant: dict[str, Any] | None, + input_quant: dict[str, Any] | None, ) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: @@ -237,8 +237,8 @@ def _is_static_tensor_w8a8( def _is_ocp_mx( self, - weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]], + weight_quant: dict[str, Any] | None, + input_quant: dict[str, Any] | None, ) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: @@ -370,7 +370,7 @@ def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": return scheme - def get_cache_scale(self, name: str) -> Optional[str]: + def get_cache_scale(self, name: str) -> str | None: """ Check whether the param name matches the format for k/v cache scales in quark. If this is the case, return its equivalent param name @@ -429,7 +429,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ): """ Use the output of create_weights and the CompressedTensorsScheme @@ -454,7 +454,7 @@ def __init__(self, quant_config: QuarkConfig): super().__init__(quant_config) @staticmethod - def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]): + def validate_kv_cache_config(kv_cache_config: dict[str, Any] | None): """ Validator for the kv cache configuration. 
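Most of the hunks in this patch apply the same mechanical typing migration: `Optional[T]` and `Union[A, B]` annotations become the PEP 604 forms `T | None` and `A | B`, and `Callable` is imported from `collections.abc` instead of `typing`. A before/after sketch with a hypothetical function, for reference:

from collections.abc import Callable
from typing import Optional, Union


def run_old(fn: Optional[Callable[[int], int]] = None) -> Union[int, None]:
    # typing-module spellings being phased out by this patch
    return fn(1) if fn is not None else None


def run_new(fn: Callable[[int], int] | None = None) -> int | None:
    # equivalent PEP 604 spelling used in the updated signatures
    return fn(1) if fn is not None else None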
Useful for controlling the kv cache quantization schemes, that are being supported in vLLM diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index f00188a6f8c4..a8f4b1b0db68 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any import torch @@ -19,8 +20,10 @@ fp8_w8a8_moe_quant_config, ocp_mx_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled, + use_mxfp4_aiter_moe, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin, @@ -333,13 +336,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - per_act_token_quant=self.weight_qscheme == "per_channel", + per_act_token_quant=self.input_qscheme == "per_channel", + per_out_ch_quant=self.weight_qscheme == "per_channel", ) def apply( @@ -350,21 +354,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: @@ -401,7 +405,7 @@ def apply( ) if self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." - return torch.ops.vllm.fused_marlin_moe( + return fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, @@ -470,22 +474,22 @@ def __init__( "not implemented. Please open an issue." 
) - if not current_platform.supports_mx(): - self.emulate = True + self.emulate = not current_platform.supports_mx() or not ( + use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" + ) + if self.emulate: logger.warning_once( - "The current platform does not support native MXFP4/MXFP6 " + f"The current mode (supports_mx={current_platform.supports_mx()}, " + f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, " + f"ocp_mx_scheme={self.ocp_mx_scheme}) " + "does not support native MXFP4/MXFP6 " "computation. Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " "layers computed in high precision." ) else: - self.emulate = True logger.warning_once( - "The current platform supports native MXFP4/MXFP6 " - "computation, but kernels are not yet integrated in vLLM. " - "Simulated weight dequantization and activation " - "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision." + "The current mode supports native MoE MXFP4 computation" ) def get_packed_dim(self, dim: int, quant_dtype: str): @@ -566,9 +570,27 @@ def create_weights( layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) + def process_weights_after_loading(self, layer): + if self.emulate: + return + + from aiter.utility.fp4_utils import e8m0_shuffle + + # Pre-shuffle weight scales + s0, s1, _ = layer.w13_weight_scale.shape + w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1) + w13_weight_scale = e8m0_shuffle(w13_weight_scale) + layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1) + + s0, s1, _ = layer.w2_weight_scale.shape + w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1) + w2_weight_scale = e8m0_shuffle(w2_weight_scale) + layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1) + torch.cuda.empty_cache() + def get_fused_moe_quant_config( self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: + ) -> FusedMoEQuantConfig | None: return ocp_mx_moe_quant_config( quant_dtype=self.input_dtype, weight_dtype=self.weight_dtype, @@ -587,21 +609,21 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: @@ -609,8 +631,6 @@ def apply( "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet." 
) - from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -626,17 +646,44 @@ def apply( indices_type=self.topk_indices_dtype, ) - out = fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, - quant_config=self.moe_quant_config, - ) + if not self.emulate: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + aiter_acts = { + ActivationType.No.name.lower(): ActivationType.No, + ActivationType.Silu.name.lower(): ActivationType.Silu, + ActivationType.Gelu.name.lower(): ActivationType.Gelu, + } + assert activation in aiter_acts, ( + f"Aiter CK fp4 MoE doesn't support activation {activation}" + ) + out = fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + quant_type=QuantType.per_1x32, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + activation=aiter_acts[activation], + doweight_stage1=False, + ) + else: + from vllm.model_executor.layers.fused_moe import fused_experts + + out = fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) return out diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index 0eefa7f7e96c..c25c522dea55 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable from fractions import Fraction from functools import cache, partial -from typing import Any, Callable, Optional, Union +from typing import Any import torch import torch.nn.functional as F @@ -44,7 +45,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op if is_rocm_aiter_fp4_asm_gemm_enabled(): from aiter import gemm_a4w4, per_1x32_f4_quant_hip @@ -54,8 +55,8 @@ def gemm_with_dynamic_quant( weight: torch.Tensor, weight_scale: torch.Tensor, rocm_use_aiter_fp4_asm_gemm: bool = False, - out_dtype: Optional[torch.dtype] = torch.bfloat16, - x_scales: Optional[torch.Tensor] = None, + out_dtype: torch.dtype | None = torch.bfloat16, + x_scales: torch.Tensor | None = None, ) -> torch.Tensor: M = x.shape[0] if rocm_use_aiter_fp4_asm_gemm: @@ -95,7 +96,7 @@ def gemm_with_dynamic_quant_fake( weight_scale: torch.Tensor, x_scales: torch.Tensor = None, rocm_use_aiter_fp4_asm_gemm: bool = False, - out_dtype: Optional[torch.dtype] = torch.bfloat16, + out_dtype: torch.dtype | None = torch.bfloat16, ) -> torch.Tensor: return torch.empty( (*x.shape[:-1], weight.shape[0]), dtype=out_dtype, device=x.device @@ -129,7 +130,7 @@ def __init__( ) if self.weight_dtype == "mxfp4": - 
self.packed_factor: Union[int, Fraction] = 2 + self.packed_factor: int | Fraction = 2 self.dequant_func = dequant_mxfp4 else: self.packed_factor = Fraction(numerator=8, denominator=6) @@ -282,7 +283,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.emulate: dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py index ddec0f6ea8eb..412a07a85fe7 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Optional import torch @@ -33,7 +32,7 @@ def create_weights(self, *args, **kwargs): @abstractmethod def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ): """ Run the forward pass for the particular scheme. This is where diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 553698a7dc94..1e5ee93b61f2 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional, cast +from collections.abc import Callable +from typing import Any, cast import torch from torch.nn import Parameter @@ -25,11 +26,11 @@ class QuarkW8A8Fp8(QuarkScheme): def __init__( - self, weight_config: dict[str, Any], input_config: Optional[dict[str, Any]] + self, weight_config: dict[str, Any], input_config: dict[str, Any] | None ): self.weight_qscheme = cast(str, weight_config.get("qscheme")) self.is_static_input_scheme: bool = False - self.input_qscheme: Optional[str] = None + self.input_qscheme: str | None = None if input_config is not None: self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) @@ -166,7 +167,7 @@ def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.fp8_linear.apply( input=x, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index c41dd05d1062..42d2ed2e85ed 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional +from collections.abc import Callable import torch @@ -27,8 +27,8 @@ class QuarkW8A8Int8(QuarkScheme): def __init__( self, qscheme: str, - is_static_input_scheme: Optional[bool], - input_symmetric: Optional[bool], + is_static_input_scheme: bool | None, + input_symmetric: bool | None, ): self.qscheme = qscheme 
self.is_static_input_scheme = is_static_input_scheme @@ -134,6 +134,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index 0eb4b20a6e52..dc82f94ebbbf 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Mapping from types import MappingProxyType -from typing import Any, Optional +from typing import Any import regex as re @@ -22,7 +22,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: def should_ignore_layer( - layer_name: Optional[str], + layer_name: str | None, ignore: Iterable[str], fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: diff --git a/vllm/model_executor/layers/quantization/qutlass_utils.py b/vllm/model_executor/layers/quantization/qutlass_utils.py new file mode 100644 index 000000000000..555bb50da199 --- /dev/null +++ b/vllm/model_executor/layers/quantization/qutlass_utils.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Modified by Roberto L. Castro (Roberto.LopezCastro@ist.ac.at). +# +# Copied from https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Literal + +import torch +from torch.library import wrap_triton + +from vllm.triton_utils import tl, triton + + +@triton.jit +def triton_scale_swizzle( + scale_ptr: torch.Tensor, + scale_rows: int, + scale_cols: int, + output_ptr: torch.Tensor, + input_row_stride: int, + output_block_stride: int, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, +): + """ + Rearranges tensor data from row-major to block-scaled swizzle format. 
+ + Args: + scale_ptr: Pointer to the input scale tensor + scale_rows: Number of rows in the scale tensor + scale_cols: Number of columns in the scale tensor + output_ptr: Pointer to the output tensor + input_row_stride: Stride between rows in the input tensor + output_block_stride: Stride between blocks in the output tensor + BLOCK_ROWS: Number of rows in a tile (compile-time constant) + BLOCK_COLS: Number of columns in a tile (compile-time constant) + """ + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + rows = tl.arange(0, BLOCK_ROWS)[:, None] + cols = tl.arange(0, BLOCK_COLS)[None, :] + + # Calculate starting row and column for this tile + start_row = pid_row * BLOCK_ROWS + start_col = pid_col * BLOCK_COLS + global_rows = start_row + rows + global_cols = start_col + cols + + mask = (global_rows < scale_rows) & (global_cols < scale_cols) + + input_scales = tl.load( + scale_ptr + global_rows * input_row_stride + global_cols, + mask=mask, + other=0.0, + ) + + r_div_32 = rows // 32 + r_mod_32 = rows % 32 + + # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates + dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols + + # Flatten + dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) + + # Calculate block offset using provided output block stride + LOCAL_NUMEL = BLOCK_ROWS * BLOCK_COLS + block_offset = pid_col * LOCAL_NUMEL + (pid_row * output_block_stride) + + tl.store( + output_ptr + block_offset + dest_indices_flat, + scales_flat, + ) + + +def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale from row-major format to + block-scaled swizzle format. + + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scale_tensor: Input tensor in row-major format with 8-bit elements + + Returns: + Rearranged tensor in block-scaled swizzle format + """ + assert scale_tensor.element_size() == 1, ( + "Expected element size to be 1 byte (8 bits)" + ) + assert scale_tensor.is_contiguous(), "Input tensor must be contiguous" + + rows, cols = scale_tensor.shape + + # Calculate blocks needed + n_row_blocks = triton.cdiv(rows, 128) + n_col_blocks = triton.cdiv(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + out = scale_tensor.new_empty((padded_rows, padded_cols)) + + # Input stride (for row-major format) + input_row_stride = cols + + # We probably want handle multiple blocks per tile but + # for now keep it simple + BLOCK_ROWS, BLOCK_COLS = 128, 4 + + # Output block stride for the rearranged format + output_block_stride = BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS) + + grid = lambda META: ( + triton.cdiv(padded_rows, BLOCK_ROWS), + triton.cdiv(padded_cols, BLOCK_COLS), + ) + + wrap_triton(triton_scale_swizzle)[grid]( + scale_tensor.view(torch.uint8), + rows, + cols, + out.view(torch.uint8), + input_row_stride, + output_block_stride, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + + return out + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def to_blocked( + input_matrix: torch.Tensor, backend: Literal["torch", "triton"] = "triton" +) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying + the rearrangement pattern. 
+ + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + backend: "torch" (PyTorch path) or "triton" (Triton kernel) + + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + if backend == "triton": + return triton_mx_block_rearrange(input_matrix).flatten() + elif backend != "torch": + raise ValueError(f'backend must be "torch" or "triton", got {backend!r}') + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + assert (rows, cols) == (padded_rows, padded_cols) + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index e0070e207048..e4f7ff833956 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -3,23 +3,20 @@ # Copyright © 2025, Oracle and/or its affiliates. import os -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional +import numpy as np import torch -import torch.nn.functional as F from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import ( - FusedMoE, - FusedMoEConfig, - FusedMoEMethodBase, -) from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEQuantConfig, - int4_w4a16_moe_quant_config, - int8_w8a16_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -30,6 +27,12 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_rtn_marlin_linear, + marlin_make_workspace_new, +) +from vllm.scalar_type import scalar_types logger = init_logger(__name__) """By default, use 8 bit as target precision, but it can be @@ -40,6 +43,9 @@ overridden by setting the RTN_GROUP_SIZE envvar """ GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128") +"""Global Marlin workspace shared by all modules +""" +workspace = None class RTNConfig(QuantizationConfig): @@ -59,6 +65,10 @@ def __init__( f"supported for RTN, but got {self.weight_bits} bits." 
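# [Editor's note] Illustrative sketch, not part of the patch: expected result of the
# `to_blocked` helper added above in vllm/model_executor/layers/quantization/qutlass_utils.py.
# The (256, 8) E8M0 scale shape is an arbitrary example (already a multiple of the
# 128x4 block, so the pure-PyTorch backend applies); every 128x4 block of scales is
# rewritten into the 32x16 swizzled layout cuBLAS expects, and the result is flattened.
import torch
from vllm.model_executor.layers.quantization.qutlass_utils import ceil_div, to_blocked

scales = torch.randint(0, 256, (256, 8), dtype=torch.uint8)
swizzled = to_blocked(scales, backend="torch")
assert swizzled.shape == (32 * ceil_div(256, 128) * 16 * ceil_div(8, 4),)  # 2048 elements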
) + self.quant_type = ( + scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8 + ) + def __repr__(self) -> str: return ( f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})" @@ -220,24 +230,32 @@ def create_weights( layer.output_size_per_partition = output_size_per_partition def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - fix_weights(layer, "weight") + """Repack weights and scales for Marlin kernels.""" + weight_bits = self.quant_config.weight_bits + + weight, scale = repack_weights(layer.weight, layer.scale, weight_bits) + + replace_parameter(layer, "weight", weight) + replace_parameter(layer, "scale", scale) + + init_workspace(layer.weight.device) def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: - qweight = layer.weight - scale = layer.scale - - weight = rtn_dequantize(qweight, scale) - out = F.linear(x, weight) - del weight - if bias is not None: - out.add_(bias) - - return out + return apply_rtn_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.scale, + workspace=workspace, + quant_type=self.quant_config.quant_type, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + bias=bias, + ) class RTNMoEMethod(FusedMoEMethodBase): @@ -314,28 +332,27 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Repack weights and scales for Marlin kernels.""" weight_bits = self.quant_config.weight_bits - fix_weights(layer, "w13_weight", weight_bits == 4) - fix_weights(layer, "w2_weight", weight_bits == 4) - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> Optional[FusedMoEQuantConfig]: - weight_bits = self.quant_config.weight_bits - group_size = self.quant_config.group_size - assert weight_bits == 4 or weight_bits == 8 - config_builder = ( - int4_w4a16_moe_quant_config - if weight_bits == 4 - else int8_w8a16_moe_quant_config + w13_weight, w13_scale = repack_weights( + layer.w13_weight, layer.w13_scale, weight_bits ) - return config_builder( - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - w1_zp=None, - w2_zp=None, - block_shape=[0, group_size], + replace_parameter(layer, "w13_weight", w13_weight) + replace_parameter(layer, "w13_scale", w13_scale) + + w2_weight, w2_scale = repack_weights( + layer.w2_weight, layer.w2_scale, weight_bits ) + replace_parameter(layer, "w2_weight", w2_weight) + replace_parameter(layer, "w2_scale", w2_scale) + + init_workspace(layer.w13_weight.device) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return None def apply( self, @@ -345,28 +362,26 @@ def apply( top_k: int, renormalize: bool, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: 
bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.") - from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -382,18 +397,22 @@ def apply( indices_type=self.topk_indices_dtype, ) - return fused_experts( + return fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, + getattr(layer, "w13_bias", None), + getattr(layer, "w2_bias", None), + layer.w13_scale, + layer.w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_config.quant_type.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - quant_config=self.moe_quant_config, + workspace=workspace, ) @@ -503,18 +522,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return input_deq -def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False): - """torch.compile does not know how to deal with a Parameter subclass - (aka RTNParameter). As we don't really need RTNParameters for the - forward pass, we replace them with equivalent instances of Parameters. 
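# [Editor's note] Sketch, not part of the patch: the Marlin repacking helpers added
# below (`repack_8bit_into_32bit` in particular) collapse four 8-bit values into one
# int32, little-endian within the word, which is the element layout the Marlin GEMM
# kernels consume. `pack4_bytes_le` is a hypothetical standalone mirror of that step:
import torch

def pack4_bytes_le(b: torch.Tensor) -> torch.Tensor:
    # b: (..., 4*k) uint8 -> (..., k) int32
    out = torch.zeros((*b.shape[:-1], b.shape[-1] // 4), dtype=torch.int32)
    for i in range(4):
        out |= (b[..., i::4] & 0xFF).to(torch.int32) << (8 * i)
    return out

x = torch.tensor([[0x01, 0x02, 0x03, 0x04]], dtype=torch.uint8)
assert pack4_bytes_le(x).tolist() == [[0x04030201]]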
+def _get_perms(): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm_arr = np.array(perm) + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel() + perm_tensor = torch.from_numpy(perm_arr) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm_tensor, scale_perm, scale_perm_single + + +_perm, _scale_perm, _scale_perm_single = _get_perms() + + +def pack_for_marlin(weight, scale, qbits): + batch = weight.shape[0] + + n = weight.size(1) + k = weight.size(2) + groupsize = k // scale.size(2) + + tile = 16 + s = scale.permute(0, 2, 1) # transpose + w = weight.permute(0, 2, 1) # transpose + if groupsize != k: + w = w.reshape((batch, -1, groupsize, n)) + w = w.permute(0, 2, 1, 3) + w = w.reshape((batch, groupsize, -1)) + s = s.reshape((batch, 1, -1)) + + if groupsize != k: + w = w.reshape((batch, groupsize, -1, n)) + w = w.permute(0, 2, 1, 3) + w = w.reshape((batch, k, n)).contiguous() + s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm] + else: + s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single] + s = s.reshape((batch, -1, n)).contiguous() + w = w.reshape((batch, k // tile, tile, n // tile, tile)) + w = w.permute((0, 1, 3, 2, 4)) + w = w.reshape((batch, k // tile, n * tile)) + res = w + res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape) + if qbits == 4: + q = torch.zeros( + (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device + ) + for i in range(2): + q |= res[:, :, i::2] << 4 * i + q = q.reshape(batch, -1, n).contiguous() + else: + q = res.clone() + q[:, :, 2::8] = res[:, :, 4::8] + q[:, :, 3::8] = res[:, :, 5::8] + q[:, :, 4::8] = res[:, :, 2::8] + q[:, :, 5::8] = res[:, :, 3::8] + q = q.reshape(batch, -1, n).to(torch.int8).contiguous() + + return q, s + + +def repack_8bit_into_32bit(input): + output = torch.zeros( + (input.shape[0], input.shape[1], input.shape[2] // 4), + dtype=torch.int32, + device=input.device, + ) + for i in range(4): + output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i + + return output + + +def repack_weights(qweight, scale, weight_bits): + batch_present = len(qweight.shape) == 3 + if not batch_present: + qweight = qweight.unsqueeze(0) + scale = scale.unsqueeze(0) + + if weight_bits == 4: + """Unpack two 4-bit values from each byte. 
+ """ + qweight_unpacked = torch.empty( + (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]), + dtype=torch.uint8, + device=qweight.device, + ) + for i in range(2): + qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape( + qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2 + ) + else: + qweight_unpacked = qweight + + qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits) + """Marlin kernels expect tensors in int32 format in a certain shape """ - old_weight = getattr(layer, param_name) - assert isinstance(old_weight, RTNParameter) - data = old_weight.data.data + qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8)) + qweight_reshaped = qweight_repacked.reshape( + qweight.shape[0], qweight.shape[2] // 16, -1 + ) + if not batch_present: + qweight_reshaped = qweight_reshaped.squeeze(0) + scale_packed = scale_packed.squeeze(0) + + return qweight_reshaped, scale_packed - delattr(layer, param_name) - if reshape: - data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1) - new_weight = Parameter(data=data, requires_grad=False) - layer.register_parameter(param_name, new_weight) +def init_workspace(device): + global workspace + if workspace is None: + workspace = marlin_make_workspace_new(device, 4) diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py index 9396da0ecd1a..669bd9d6ed83 100644 --- a/vllm/model_executor/layers/quantization/schema.py +++ b/vllm/model_executor/layers/quantization/schema.py @@ -13,8 +13,6 @@ scaling factors. """ -from typing import Optional - from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator @@ -75,7 +73,7 @@ class QuantParamSchema(BaseModel): # TODO: Generalize and extend with more fields # (e.g. 
weights/activations params) once functionality is enabled model_config = ConfigDict(protected_namespaces=()) - model_type: Optional[str] + model_type: str | None kv_cache: KVCacheQuantSchema @model_validator(mode="after") diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index 55eb2890bb2f..f42c45dae76d 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -5,6 +5,7 @@ from importlib.util import find_spec from typing import Any, Optional +import regex as re import torch import torch.nn.functional as F from packaging import version @@ -62,7 +63,7 @@ class TorchAOConfig(QuantizationConfig): def __init__( self, torchao_config, - skip_modules: Optional[list[str]] = None, + skip_modules: list[str] | None = None, is_checkpoint_torchao_serialized: bool = False, ) -> None: """ @@ -192,9 +193,26 @@ def get_quant_method( module_fqn = prefix if isinstance(self.torchao_config, ModuleFqnToConfig): module_fqn_to_config = self.torchao_config.module_fqn_to_config - c = module_fqn_to_config.get(module_fqn) or module_fqn_to_config.get( - "_default", None - ) + c = None + if module_fqn in module_fqn_to_config: + assert not module_fqn.startswith("re:"), ( + "module fqn should not start with" + "`re:`, which is used for specifying regex" + ) + c = module_fqn_to_config[module_fqn] + else: + for maybe_module_fqn_pattern in module_fqn_to_config: + if not maybe_module_fqn_pattern.startswith("re:"): + continue + elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn): + # we'll apply the config for first fully matched pattern + c = module_fqn_to_config[maybe_module_fqn_pattern] + break + else: + # fallback to use default if no module specific + # config is provided + c = module_fqn_to_config.get("_default", None) + if c is not None: current_torchao_config = TorchAOConfig( c, self.skip_modules, self.is_checkpoint_torchao_serialized @@ -283,7 +301,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return F.linear(x, layer.weight, bias) diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py index a24cd41659a0..64bfa8fb80eb 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -119,7 +119,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: try: import torch_xla.experimental.custom_kernel # noqa: F401 diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py index 4b7a22a26653..62a4f9036688 100644 --- a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch from packaging import version @@ -29,7 +28,7 @@ # Determines the supported quantization types for BitBLAS based on the # device's capability and whether zero-point (zp) is used. 
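# [Editor's note] Sketch, not part of the patch: how the regex-aware lookup added to
# TorchAOConfig.get_quant_method (torchao.py, above) resolves a per-module config.
# The config keys and module names are made-up examples; a key starting with "re:"
# is a regex that must fully match the module fqn, the first full match wins, and
# "_default" is the fallback when nothing matches.
import regex as re  # torchao.py uses the `regex` package; stdlib `re` would also work here

def resolve_module_config(module_fqn: str, module_fqn_to_config: dict):
    if module_fqn in module_fqn_to_config:
        return module_fqn_to_config[module_fqn]
    for pattern in module_fqn_to_config:
        if pattern.startswith("re:") and re.fullmatch(pattern[3:], module_fqn):
            return module_fqn_to_config[pattern]  # first fully matched pattern wins
    return module_fqn_to_config.get("_default", None)

cfg = {"re:.*gate_proj": "int4", "model.layers.0.mlp.up_proj": "int8", "_default": "fp8"}
assert resolve_module_config("model.layers.3.mlp.gate_proj", cfg) == "int4"
assert resolve_module_config("model.layers.0.mlp.up_proj", cfg) == "int8"
assert resolve_module_config("lm_head", cfg) == "fp8"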
def query_bitblas_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: bool, device_capability: int | None = None ): if device_capability is None: capability_tuple = current_platform.get_device_capability() @@ -52,10 +51,10 @@ def query_bitblas_supported_quant_types( def _check_bitblas_supported( quant_type: ScalarType, - group_size: Optional[int], + group_size: int | None, has_zp: bool, - device_capability: Optional[int] = None, -) -> tuple[bool, Optional[str]]: + device_capability: int | None = None, +) -> tuple[bool, str | None]: if device_capability is None: capability_tuple = current_platform.get_device_capability() device_capability = ( @@ -99,7 +98,7 @@ def check_bitblas_supported( quant_type: ScalarType, group_size: int, has_zp: bool = False, - device_capability: Optional[int] = None, + device_capability: int | None = None, ) -> bool: cond, _ = _check_bitblas_supported( quant_type, group_size, has_zp, device_capability @@ -156,7 +155,7 @@ def check_bitblas_supports_shape( input_size_per_partition: int, input_size: int, group_size: int, -) -> tuple[bool, Optional[str]]: +) -> tuple[bool, str | None]: try: verify_bitblas_supports_shape( output_size_per_partition, input_size_per_partition, input_size, group_size diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 7059a029ba67..b3a4cb2de139 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility helpers for NVFP4 + FlashInfer fused-MoE path""" -from __future__ import annotations - import torch import vllm.envs as envs @@ -29,12 +27,12 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool: - """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" + """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" return ( envs.VLLM_USE_FLASHINFER_MOE_FP4 and has_flashinfer_cutlass_fused_moe() and current_platform.is_cuda() - and current_platform.is_device_capability(100) + and current_platform.has_device_capability(100) ) @@ -60,7 +58,7 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize( ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 - enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv" + enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv" return create_flashinfer_prepare_finalize( use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 7f32ef00647c..50ea049c3d5a 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum -from typing import Optional import torch @@ -101,10 +100,10 @@ def apply_flashinfer_per_tensor_scale_fp8( layer: torch.nn.Module, hidden_states: torch.Tensor, router_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], + routing_bias: torch.Tensor | None, top_k: int, - num_expert_group: Optional[int], - 
topk_group: Optional[int], + num_expert_group: int | None, + topk_group: int | None, global_num_experts: int, apply_router_weight_on_input: bool, ) -> torch.Tensor: @@ -186,7 +185,7 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None: def build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe: Optional[FusedMoEConfig], + moe: FusedMoEConfig | None, ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False @@ -194,9 +193,9 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize( def select_cutlass_fp8_gemm_impl( - moe: Optional[FusedMoEConfig], + moe: FusedMoEConfig | None, quant_config: FusedMoEQuantConfig, - out_dtype: Optional[torch.dtype] = None, + out_dtype: torch.dtype | None = None, ) -> mk.FusedMoEPermuteExpertsUnpermute: """Return a GEMM *experts* implementation for fused-MoE layers""" @@ -225,7 +224,7 @@ def flashinfer_cutlass_moe_fp8( inplace: bool = False, activation: str = "silu", global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: quant_config = layer.quant_method.get_fused_moe_quant_config(layer) @@ -264,3 +263,9 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend: f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" f" expected one of {allowed_backends}" ) + + +def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool: + # TODO(shuw@nvidia): Update when new backends are added. + backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,) + return backend in backends_supporting_global_sf diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 16ede6113a94..f25148abb619 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -5,8 +5,8 @@ import functools import json import os -from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Sequence +from typing import Any import torch @@ -28,18 +28,18 @@ ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op from vllm.utils.deep_gemm import ( fp8_gemm_nt, is_deep_gemm_e8m0_used, is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear, ) +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) -def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: +def is_fp8(x: torch.dtype | torch.Tensor) -> bool: if isinstance(x, torch.Tensor): x = x.dtype return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz @@ -54,7 +54,7 @@ def cutlass_scaled_mm( Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, - is_hopper: Optional[bool] = None, + is_hopper: bool | None = None, ) -> torch.Tensor: if is_hopper is None: is_hopper = current_platform.is_device_capability(90) @@ -279,8 +279,8 @@ def apply( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, + input_scale: torch.Tensor | None = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: assert input_scale is None # View input as 2D matrix for fp8 methods @@ -360,7 +360,7 @@ def _run_aiter( 
weight, input_scale, weight_scale, - self.weight_group_shape, + list(self.weight_group_shape), input_2d.dtype, ) @@ -377,7 +377,7 @@ def _run_triton( weight, input_scale, weight_scale, - self.weight_group_shape, + list(self.weight_group_shape), input_2d.dtype, ) @@ -394,7 +394,7 @@ def _dispatch_w8a8_blockscale_op( ], torch.Tensor, ], - Optional[QuantFP8], + QuantFP8 | None, ]: if use_cutlass: return self._run_cutlass, ( @@ -418,7 +418,7 @@ def _dispatch_w8a8_blockscale_op( def input_to_float8( - x: torch.Tensor, dtype: Optional[torch.dtype] = None + x: torch.Tensor, dtype: torch.dtype | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to float8 values " "with tensor-wise quantization.""" @@ -568,10 +568,10 @@ def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, eps: float = 1e-10, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype | None = None, column_major_scales: bool = False, - out_q: Optional[torch.Tensor] = None, - use_ue8m0: Optional[bool] = None, + out_q: torch.Tensor | None = None, + use_ue8m0: bool | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. It converts the tensor values into signed float8 values and returns the @@ -754,7 +754,7 @@ def _w8a8_triton_block_scaled_mm( @functools.lru_cache def get_w8a8_block_fp8_configs( N: int, K: int, block_n: int, block_k: int -) -> Optional[dict[int, Any]]: +) -> dict[int, Any] | None: """ Return optimized configurations for the w8a8 block fp8 kernel. The return value will be a dictionary that maps an irregular grid of @@ -887,11 +887,11 @@ def requant_weight_ue8m0_inplace( UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace. Args: - weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``. - Expected shape ``(..., M, K)``. - weight_scale: Corresponding per-block scale tensor (``torch.float32``) - with shape ``(..., M // block_size[0], K // block_size[1])``. - block_size: 2-element iterable ``[block_m, block_k]`` describing the + weight: Block-quantised weight tensor stored in `torch.float8_e4m3fn`. + Expected shape `(..., M, K)`. + weight_scale: Corresponding per-block scale tensor (`torch.float32`) + with shape `(..., M // block_size[0], K // block_size[1])`. + block_size: 2-element iterable `[block_m, block_k]` describing the block quantisation granularity. 
""" if weight.numel() == 0: @@ -1012,7 +1012,7 @@ def validate_fp8_block_shape( def create_fp8_weight_parameter( output_size_per_partition: int, input_size_per_partition: int, - weight_loader: Optional[Callable], + weight_loader: Callable | None, ) -> torch.nn.Parameter: """Create FP8 weight parameter.""" from vllm.model_executor.parameter import ModelWeightParameter @@ -1033,8 +1033,8 @@ def create_fp8_scale_parameter( parameter_type: torch.nn.Parameter, output_partition_sizes: list[int], input_size_per_partition: int, - block_size: Optional[list[int]], - weight_loader: Optional[Callable], + block_size: list[int] | None, + weight_loader: Callable | None, ) -> torch.nn.Parameter: """Create scale parameter based on quantization strategy.""" if parameter_type == ChannelQuantScaleParameter: @@ -1070,7 +1070,7 @@ def create_fp8_scale_parameter( def create_fp8_input_scale( - output_partition_sizes: list[int], weight_loader: Optional[Callable] + output_partition_sizes: list[int], weight_loader: Callable | None ) -> torch.nn.Parameter: """Create input scale parameter for static activation quantization.""" from vllm.model_executor.parameter import PerTensorScaleParameter @@ -1087,8 +1087,8 @@ def process_fp8_weight_tensor_strategy( weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: list[int], - input_scale: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + input_scale: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Process weights for tensor-wise quantization strategy.""" from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, @@ -1114,8 +1114,8 @@ def process_fp8_weight_tensor_strategy( def process_fp8_weight_channel_strategy( weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + input_scale: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Process weights for channel-wise quantization strategy.""" from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( normalize_e4m3fn_to_e4m3fnuz, diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py index 6209dda955ce..dfebeca93392 100644 --- a/vllm/model_executor/layers/quantization/utils/gptq_utils.py +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -4,7 +4,7 @@ from copy import deepcopy from fractions import Fraction from types import MappingProxyType -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import regex as re import torch @@ -25,7 +25,7 @@ # Match dynamic rules with module name (prefix) and override quantize # config if module (prefix) matches a rule -def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str): +def override_config(config: GPTQConfig | GPTQMarlinConfig, prefix: str): weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) if isinstance(weight_bits, int): config.weight_bits = weight_bits @@ -60,11 +60,11 @@ def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str): def get_dynamic_override( - config: Union[GPTQConfig, GPTQMarlinConfig], + config: GPTQConfig | GPTQMarlinConfig, layer_name: str, - key: Optional[str] = None, - default_value: Union[int, bool, None] = None, -) -> Union[dict, int, bool, None]: + key: str | None 
= None, + default_value: int | bool | None = None, +) -> dict | int | bool | None: for pattern, pattern_dict in config.dynamic.items(): # Negative match: matched modules are excluded from quantized init if pattern.startswith("-:"): @@ -126,7 +126,7 @@ def is_layer_gptq_quantized( def get_linear_quant_method( - config: Union[GPTQConfig, GPTQMarlinConfig], + config: GPTQConfig | GPTQMarlinConfig, layer: torch.nn.Module, prefix: str, linear_method_cls: type, diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 1b8efe4332c5..925d0a516ce6 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -6,7 +6,7 @@ import json import logging import os -from typing import Any, Optional +from typing import Any import torch @@ -21,8 +21,8 @@ def apply_w8a8_block_int8_linear( weight: torch.Tensor, block_size: list[int], weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, + input_scale: torch.Tensor | None = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: assert input_scale is None # View input as 2D matrix for fp8 methods @@ -359,7 +359,7 @@ def _w8a8_block_int8_matmul( @functools.lru_cache def get_w8a8_block_int8_configs( N: int, K: int, block_n: int, block_k: int -) -> Optional[dict[int, Any]]: +) -> dict[int, Any] | None: """ Return optimized configurations for the w8a8 block fp8 kernel. diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py index 4bf31340a2f6..3b8c9a8b6ca1 100644 --- a/vllm/model_executor/layers/quantization/utils/layer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Union import torch @@ -21,7 +20,7 @@ def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): # Newly generated tensors need to replace existing tensors that are # already registered as parameters by vLLM (and won't be freed) def replace_parameter( - mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter] + mod: torch.nn.Module, name: str, new: torch.Tensor | torch.nn.Parameter ) -> None: old = getattr(mod, name) if ( diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py index 69466bdcb64c..ccfcdac1ec0f 100644 --- a/vllm/model_executor/layers/quantization/utils/machete_utils.py +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -41,7 +40,7 @@ def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]: def check_machete_supports_shape( in_features: int, out_featrues: int -) -> tuple[bool, Optional[str]]: +) -> tuple[bool, str | None]: if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: return ( False, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index d2fa5af1b854..071fb4ba1686 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ 
-1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import numpy import torch @@ -34,9 +33,9 @@ # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: Optional[bool] = None, + has_zp: bool | None = None, include_fp_type: bool = True, - device_capability: Optional[int] = None, + device_capability: int | None = None, ): if device_capability is None: capability_tuple = current_platform.get_device_capability() @@ -72,10 +71,10 @@ def query_marlin_supported_quant_types( def _check_marlin_supported( quant_type: ScalarType, - group_size: Optional[int], + group_size: int | None, has_zp: bool, - device_capability: Optional[int] = None, -) -> tuple[bool, Optional[str]]: + device_capability: int | None = None, +) -> tuple[bool, str | None]: if device_capability is None: capability_tuple = current_platform.get_device_capability() device_capability = ( @@ -109,7 +108,7 @@ def check_marlin_supported( quant_type: ScalarType, group_size: int, has_zp: bool = False, - device_capability: Optional[int] = None, + device_capability: int | None = None, ) -> bool: cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) return cond @@ -164,7 +163,7 @@ def check_marlin_supports_shape( input_size_per_partition: int, input_size: int, group_size: int, -) -> tuple[bool, Optional[str]]: +) -> tuple[bool, str | None]: try: verify_marlin_supports_shape( output_size_per_partition, input_size_per_partition, input_size, group_size @@ -445,7 +444,7 @@ def apply_gptq_marlin_linear( output_size_per_partition: int, input_size_per_partition: int, is_k_full: bool, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) @@ -494,7 +493,7 @@ def apply_awq_marlin_linear( quant_type: ScalarType, output_size_per_partition: int, input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) @@ -529,3 +528,48 @@ def apply_awq_marlin_linear( ) return output.reshape(out_shape) + + +def apply_rtn_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: torch.Tensor | None = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) + + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + bias, + weight_scale, + None, + None, + None, + None, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py 
b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index c5e34f392fb2..842fb9b62267 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -95,11 +94,11 @@ def apply_fp4_marlin_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - weight_scale_2: Optional[torch.Tensor], + weight_scale_2: torch.Tensor | None, workspace: torch.Tensor, size_n: int, size_k: int, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: # For GPUs that lack FP4 hardware support, we can leverage the diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 9348ac158daa..8c96848a8539 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -45,7 +44,7 @@ def apply_fp8_marlin_linear( workspace: torch.Tensor, size_n: int, size_k: int, - bias: Optional[torch.Tensor], + bias: torch.Tensor | None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 1bbd88d5ca71..89756c45ef55 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility functions used for tests and benchmarks""" -from typing import Optional - import numpy as np import torch @@ -100,7 +98,7 @@ def marlin_quantize( quant_type: ScalarType, group_size: int, act_order: bool, - test_perm: Optional[torch.Tensor] = None, + test_perm: torch.Tensor | None = None, ): size_k, size_n = w.shape num_bits = quant_type.size_bits diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index ee6c826f8b2c..5e87cadfb107 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import torch from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) @@ -71,17 +72,17 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): def _can_support_mxfp4( use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: 
Optional[Callable] = None, - e_score_correction_bias: Optional[torch.Tensor] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + e_score_correction_bias: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, scoring_func: str = "softmax", activation: str = "swigluoai", - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, ): return not ( use_grouped_topk diff --git a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py index 2249e9658970..2b5659e30097 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py @@ -3,7 +3,7 @@ import torch from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def _quant_dequant_mxfp6( diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index 248b2d6c4af2..bed771fd1c4d 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -18,4 +18,7 @@ def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: "`pip install flashinfer`" ) from err - return mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False) + x_q, x_scales = mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False) + if x_scales.ndim == 1: + x_scales = x_scales.view(x.size(0), -1) + return x_q, x_scales diff --git a/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py b/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py index 3c71441a3df7..7752324f41fe 100644 --- a/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py +++ b/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum -from typing import Union from vllm.logger import init_logger @@ -28,9 +27,7 @@ class OCP_MX_Scheme(str, Enum): w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3" @classmethod - def from_quant_dtype( - cls, input_dtype: Union[str, None], weight_dtype: Union[str, None] - ): + def from_quant_dtype(cls, input_dtype: str | None, weight_dtype: str | None): if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES: return None elif input_dtype == "mxfp4" and weight_dtype == "mxfp4": diff --git a/vllm/model_executor/layers/quantization/utils/petit_utils.py b/vllm/model_executor/layers/quantization/utils/petit_utils.py index 1f053103fc3c..081f53eac939 100644 --- a/vllm/model_executor/layers/quantization/utils/petit_utils.py +++ b/vllm/model_executor/layers/quantization/utils/petit_utils.py @@ -43,8 +43,8 @@ def _import_petit_kernel() -> "ModuleType": def _check_petit_nvfp4_supported( - quant_method: str, group_size: Optional[int] -) -> tuple[bool, Optional[str]]: + quant_method: str, group_size: int | None +) -> tuple[bool, str | None]: if quant_method != "NVFP4": 
return ( False, @@ -62,7 +62,7 @@ def _check_petit_nvfp4_supported( return (True, None) -def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None: +def verify_petit_nvfp4_supported(quant_method: str, group_size: int | None) -> None: supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size) if not supported: assert error_msg is not None @@ -98,7 +98,7 @@ def apply_petit_nvfp4_linear( weight_scale_2: torch.Tensor, size_n: int, size_k: int, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: # Trigger (or get) the import here as well. petit_kernel = _import_petit_kernel() diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 2e9b279465f9..d056d3404385 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -5,7 +5,7 @@ from collections.abc import Mapping from dataclasses import dataclass from types import MappingProxyType -from typing import ClassVar, NamedTuple, Optional +from typing import ClassVar, NamedTuple import numpy import torch @@ -91,7 +91,7 @@ class QuantKey: dtype: torch.dtype scale: ScaleDesc - scale2: Optional[ScaleDesc] = None + scale2: ScaleDesc | None = None symmetric: bool = True def __str__(self): @@ -205,7 +205,7 @@ def scaled_quantize( def scaled_dequantize( x_q: torch.Tensor, x_s: torch.Tensor, - group_shape: Optional[GroupShape] = None, + group_shape: GroupShape | None = None, out_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: if group_shape is not None: @@ -285,7 +285,18 @@ def is_layer_skipped( prefix: str, ignored_layers: list[str], fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), + *, + skip_with_substr: bool = False, ) -> bool: + def prefix_full_match(prefix: str, ignored_layers: list[str]) -> bool: + return prefix in ignored_layers + + # For case like: ignored_layers = ["self_attn"] + def substr_match(prefix: str, ignored_layers: list[str]) -> bool: + return any(layer in prefix for layer in ignored_layers) + + match_func = substr_match if skip_with_substr else prefix_full_match + # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj proj_name = prefix.split(".")[-1] @@ -302,7 +313,7 @@ def is_layer_skipped( is_skipped = None for shard_prefix in shard_prefixes: - is_shard_skipped = shard_prefix in ignored_layers + is_shard_skipped = match_func(shard_prefix, ignored_layers) if is_skipped is None: is_skipped = is_shard_skipped @@ -312,16 +323,16 @@ def is_layer_skipped( "are quantized. All shards of fused layers " "to have the same precision." 
) - elif "experts" in prefix: + elif "experts" in prefix and not skip_with_substr: + expert_ignore_layers = filter( + lambda layer_name: "experts" in layer_name, ignored_layers + ) return any( - [ - prefix in layer_name - for layer_name in ignored_layers - if "experts" in layer_name - ] + prefix in layer_name if not skip_with_substr else layer_name in prefix + for layer_name in expert_ignore_layers ) else: - is_skipped = prefix in ignored_layers + is_skipped = match_func(prefix, ignored_layers) assert is_skipped is not None return is_skipped @@ -336,7 +347,7 @@ def permute_rows( q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int, - test_perm: Optional[torch.Tensor] = None, + test_perm: torch.Tensor | None = None, ): assert q_w.shape == w_ref.shape @@ -365,7 +376,7 @@ def permute_rows( def quantize_weights( w: torch.Tensor, quant_type: ScalarType, - group_size: Optional[int], + group_size: int | None, zero_points: bool = False, ref_zero_points_after_scales: bool = False, ): @@ -466,7 +477,7 @@ def gptq_quantize_weights( quant_type: ScalarType, group_size: int, act_order: bool, - test_perm: Optional[torch.Tensor] = None, + test_perm: torch.Tensor | None = None, ): size_k, _ = w.shape diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index c26cd4f28cb6..380431e86435 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,19 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, Optional, Union +from collections.abc import Callable import torch from packaging import version from vllm import _custom_ops as ops from vllm import envs -from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.config import CompilationMode, get_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer +from vllm.utils.torch_utils import direct_register_custom_op # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale @@ -75,7 +75,7 @@ def cutlass_group_gemm_supported() -> bool: def per_tensor_dequantize( - tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] + tensor: torch.Tensor, inv_scale: float | torch.Tensor ) -> torch.Tensor: fake_qweight = tensor.to(torch.float16) dq_weight = fake_qweight * inv_scale @@ -399,7 +399,7 @@ def __init__( self, act_quant_static: bool, act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: Optional[bool] = None, + pad_output: bool | None = None, ): if current_platform.is_rocm(): self.preferred_backend = "rocm" @@ -419,7 +419,7 @@ def __init__( if pad_output is None: config = get_current_vllm_config().compilation_config pad_output = ( - config.level < CompilationLevel.PIECEWISE + config.mode < CompilationMode.VLLM_COMPILE and self.preferred_backend == "torch" ) @@ -437,10 +437,10 @@ def apply( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - out_dtype: Optional[torch.dtype] = None, - input_scale: Optional[torch.Tensor] = None, - input_scale_ub: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, + out_dtype: torch.dtype | None = None, + input_scale: torch.Tensor | None = None, + input_scale_ub: torch.Tensor | None = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. @@ -464,8 +464,16 @@ def apply( else: qinput, x_scale = input_2d, input_scale - per_tensor_weights = weight_scale.numel() == 1 - per_tensor_activations = x_scale.numel() == 1 + # Must have dim() conditions + # In per-token quant scenario, when the number of token is 1, + # the scale will only have 1 elements. + # Without checking the dim(), + # we cannot distingushes between per-tensor and per-token quant. + # Example: + # When the number of token is 1, per-token scale is [[1]] + # When per-tensor scale is [1] or (). + per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2 + per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2 # TODO(luka) do this dispatch during init (after ScaledMM refactor) w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( @@ -486,8 +494,8 @@ def apply( def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + input_scale: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: assert weight.dtype == torch.float8_e4m3fn # The bits pattern 10000000(-128) represents zero in e4m3fn # but NaN in e4m3fnuz. So here we set it to 0. 
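# [Editor's note] Sketch, not part of the patch: why the dim() checks added to the
# FP8 apply() path in w8a8_utils.py (above) are needed. With a single token, a
# per-token scale and a per-tensor scale both hold exactly one element; only the
# tensor rank tells them apart. Values and shapes below are illustrative.
import torch

per_tensor_scale = torch.tensor(0.5)             # shape (), numel 1, dim 0
per_token_scale_one_tok = torch.tensor([[0.5]])  # shape (1, 1), numel 1, dim 2

def looks_per_tensor(scale: torch.Tensor) -> bool:
    # mirrors: (scale.numel() == 1) and scale.dim() < 2
    return scale.numel() == 1 and scale.dim() < 2

assert looks_per_tensor(per_tensor_scale)
assert not looks_per_tensor(per_token_scale_one_tok)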
diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py index 6ae2db0f428c..c9fa8054625e 100644 --- a/vllm/model_executor/layers/resampler.py +++ b/vllm/model_executor/layers/resampler.py @@ -34,8 +34,8 @@ """ import math +from collections.abc import Callable from functools import partial -from typing import Callable, Optional, Union import numpy as np import torch @@ -48,9 +48,7 @@ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) -def get_abs_pos( - abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, int] -) -> torch.Tensor: +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor | int) -> torch.Tensor: # abs_pos: L, C # tgt_size: (H, W) # return: M, C @@ -124,7 +122,7 @@ def get_2d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed( embed_dim: int, - grid_size: Union[int, tuple[int, int]], + grid_size: int | tuple[int, int], cls_token: bool = False, version: tuple[int, int] = (2, 0), ) -> torch.Tensor: @@ -168,10 +166,10 @@ def __init__( num_queries: int, embed_dim: int, num_heads: int, - kv_dim: Optional[int] = None, + kv_dim: int | None = None, norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, do_post_projection: bool = True, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -200,12 +198,10 @@ def __init__( self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) self.do_post_projection = do_post_projection - self.ln_post = norm_layer(embed_dim) if do_post_projection else None - self.proj = ( - nn.Parameter((embed_dim**-0.5) * torch.empty(embed_dim, embed_dim)) - if do_post_projection - else None - ) + if self.do_post_projection: + self.ln_post = norm_layer(embed_dim) + data = (embed_dim**-0.5) * torch.empty(embed_dim, embed_dim) + self.proj = nn.Parameter(data=data) def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) @@ -224,11 +220,11 @@ def __init__( grid_size: int, embed_dim: int, num_heads: int, - kv_dim: Optional[int] = None, + kv_dim: int | None = None, norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, adaptive: bool = False, do_post_projection: bool = True, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__( @@ -252,8 +248,8 @@ def __init__( def forward( self, x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: torch.Tensor | None = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: if tgt_sizes is None: tgt_sizes = int(math.sqrt(x.size(1))) diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index e6956de4bfaa..64187c97cab7 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Rotary Positional Embeddings.""" -from typing import Any, Optional +from typing import Any import torch @@ -28,10 +28,10 @@ def get_rope( max_position: int, base: float, is_neox_style: bool = True, - rope_scaling: Optional[dict[str, Any]] = None, - dtype: Optional[torch.dtype] = None, + rope_scaling: dict[str, Any] | None = None, + dtype: torch.dtype | None = None, partial_rotary_factor: float = 1.0, - dual_chunk_attention_config: Optional[dict[str, Any]] = None, + dual_chunk_attention_config: 
dict[str, Any] | None = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index cf50b60118b9..711902f0cc67 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Rotary Positional Embeddings Base Class.""" -from typing import Optional - import torch from vllm.model_executor.custom_op import CustomOp @@ -92,8 +90,8 @@ def forward_native( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """A PyTorch-native implementation of forward().""" positions = positions.flatten() num_tokens = positions.shape[0] @@ -121,8 +119,8 @@ def forward_cuda( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: if self.use_flashinfer: torch.ops.vllm.flashinfer_rotary_embedding( positions, @@ -154,8 +152,8 @@ def forward_hip( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: if self.is_rocm_triton_rotary_embedding_enabled: self._match_cos_sin_cache_dtype(query) rocm_aiter_rotary_emb( @@ -167,18 +165,15 @@ def forward_hip( self.rotary_dim, self.is_neox_style, ) - else: - # ops.rotary_embedding() is an in-place operation - # that updates the query and key tensors. 
- self.forward_cuda(positions, query, key) - return query, key + return query, key + return self.forward_cuda(positions, query, key) def forward_xpu( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: from vllm._ipex_ops import ipex_ops as ops self._match_cos_sin_cache_dtype(query) diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 124ea0236cbf..9e6ec9fdd523 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -2,15 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math +from collections.abc import Callable from functools import cache from importlib.util import find_spec -from typing import Callable, Optional import torch from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op if current_platform.is_cuda(): from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb @@ -72,7 +72,7 @@ def apply_rotary_emb_dispatch( @cache def dispatch_rotary_emb_function( - default: Optional[Callable[..., torch.Tensor]] = None, + default: Callable[..., torch.Tensor] | None = None, ) -> Callable[..., torch.Tensor]: if current_platform.is_cuda(): return apply_rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index eaedca9b5219..2e5efec06663 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import Optional import torch @@ -110,9 +109,9 @@ def forward_native( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """PyTorch-native implementation equivalent to forward().""" assert key is not None self._match_cos_sin_cache_dtype(query) @@ -151,7 +150,7 @@ def forward_cuda( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: return self.forward_native(positions, query, key, offsets) diff --git a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py index 0e6eddda772f..b5dd94cc7f53 100644 --- a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -121,7 +120,7 @@ def forward_native( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = 
None, + offsets: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) @@ -185,7 +184,7 @@ def forward_cuda( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, + offsets: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: return self.forward_native(positions, query, key, offsets) diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 2bc0477c5af2..749cdbe88a62 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -16,8 +15,8 @@ def forward_native( # type: ignore[override] self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert positions.ndim == 1 or positions.ndim == 2 assert key is not None @@ -71,6 +70,6 @@ def forward_cuda( # type: ignore[override] self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: return self.forward_native(positions, query, key) diff --git a/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py index cbb3ee4e9974..bb51dcf1c6f5 100644 --- a/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Union # Adapted from # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py @@ -65,7 +64,7 @@ def __init__( max_position_embeddings: int, base: float, is_neox_style: bool, - scaling_factors: Union[list[float], float], + scaling_factors: list[float] | float, dtype: torch.dtype, ) -> None: if isinstance(scaling_factors, float): diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py index 0b808e31c903..6241cb5abbc8 100644 --- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import Optional import torch @@ -56,8 +55,8 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: def forward_native( # type: ignore[override] self, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert key is not None # self.cos_sin_cache here is complex tensor so we cannot cast into # query's dtype directly with self._match_cos_sin_cache_dtype @@ -76,6 +75,13 @@ def forward_native( # type: ignore[override] def forward_cuda( # type: 
ignore[override] self, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return self.forward_native(query, key) + + def forward_hip( # type: ignore[override] + self, + query: torch.Tensor, + key: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: return self.forward_native(query, key) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 120979970679..d269733083d8 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -1,12 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools -from typing import Optional, Union import numpy as np import torch -from transformers import PretrainedConfig from vllm.triton_utils import tl, triton @@ -213,11 +210,11 @@ def __init__( base: float, is_neox_style: bool, dtype: torch.dtype, - mrope_section: Optional[list[int]] = None, + mrope_section: list[int] | None = None, mrope_interleaved: bool = False, # YaRN parameters. *, - scaling_factor: Optional[float] = None, + scaling_factor: float | None = None, extrapolation_factor: float = 1, attn_factor: float = 1, beta_fast: int = 32, @@ -266,9 +263,9 @@ def forward_native( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """PyTorch-native implementation equivalent to forward(). 
Args: @@ -319,9 +316,9 @@ def forward_cuda( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert positions.ndim == 1 or positions.ndim == 2 assert key is not None @@ -364,1033 +361,20 @@ def forward_xpu( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: return self.forward_native(positions, query, key, offsets) def forward_cpu( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: return self.forward_native(positions, query, key, offsets) - @classmethod - def get_input_positions( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - second_per_grid_ts: Optional[list[float]], - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[list[list[int]], int]: - """Get mrope input positions and delta value.""" - - image_grid_thw = [] if image_grid_thw is None else image_grid_thw - video_grid_thw = [] if video_grid_thw is None else video_grid_thw - second_per_grid_ts = [] if second_per_grid_ts is None else second_per_grid_ts - - llm_positions, mrope_position_delta = cls.get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - return llm_positions.tolist(), mrope_position_delta - - @classmethod - def get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[torch.Tensor, int]: - from vllm.transformers_utils.config import thinker_uses_mrope - - if thinker_uses_mrope(hf_config): - return cls._omni_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - elif hf_config.model_type in ["glm4v", "glm4v_moe"]: - return cls._glm4v_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - elif 
hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: - return cls._qwen3vl_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: - return cls._ernie_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - elif "KeyeVL1_5" in hf_config.model_type: - return cls._keye_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - context_len=context_len, - seq_len=seq_len, - ) - else: - return cls._vl_get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - ) - - @classmethod - def _glm4v_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" - - image_token_id = hf_config.image_token_id - video_start_token_id = hf_config.video_start_token_id - video_end_token_id = hf_config.video_end_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - llm_pos_ids_list: list = [] - - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - - input_token_type: list[str] = [] - video_check_flg = False - for token in input_tokens: - if token == video_start_token_id: - video_check_flg = True - elif token == video_end_token_id: - video_check_flg = False - - if (token == image_token_id) and (video_check_flg is False): - input_token_type.append("image") - elif (token == image_token_id) and (video_check_flg is True): - input_token_type.append("video") - else: - input_token_type.append("text") - - input_type_group: list[tuple[str, int, int]] = [] - for key, group_iter in itertools.groupby( - enumerate(input_token_type), lambda x: x[1] - ): - group_list = list(group_iter) - start_index = group_list[0][0] - end_index = group_list[-1][0] + 1 - input_type_group.append((key, start_index, end_index)) - - video_frame_num = 1 - mm_data_idx = 0 - for modality_type, start_idx, end_idx in input_type_group: - st_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) - if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx - ) - mm_data_idx += 1 - - elif modality_type == "video": - 
t, h, w = ( - video_frame_num, - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - - for t_idx in range(llm_grid_t): - t_index = ( - torch.tensor(t_idx) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(1, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(1, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx - ) - - mm_data_idx += 1 - video_frame_num += 1 - - else: - text_len = end_idx - start_idx - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - video_frame_num = 1 - - else: - text_len = len(input_tokens) - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - return llm_positions, mrope_position_delta - - @classmethod - def _qwen3vl_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" - - video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id - ).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - 
.expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] - return llm_positions, mrope_position_delta - - @classmethod - def _ernie_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for Ernie VL.""" - - image_token_id = hf_config.im_patch_id - video_start_token_id = hf_config.video_start_token_id - video_end_token_id = hf_config.video_end_token_id - spatial_conv_size = hf_config.spatial_conv_size - temporal_conv_size = hf_config.temporal_conv_size - llm_pos_ids_list: list = [] - - if not (image_grid_thw is None and video_grid_thw is None): - if isinstance(image_grid_thw, torch.Tensor): - image_grid_thw = image_grid_thw.tolist() - - input_token_type: list[str] = [] - video_check_flg = False - for token in input_tokens: - if token == video_start_token_id: - video_check_flg = True - elif token == video_end_token_id: - video_check_flg = False - - if (token == image_token_id) and (video_check_flg is False): - input_token_type.append("image") - elif (token == image_token_id) and (video_check_flg is True): - input_token_type.append("video") - else: - input_token_type.append("text") - - input_type_group: list[tuple[str, int, int]] = [] - for key, group_iter in itertools.groupby( - enumerate(input_token_type), lambda x: x[1] - ): - group_list = list(group_iter) - start_index = group_list[0][0] - end_index = group_list[-1][0] + 1 - input_type_group.append((key, start_index, end_index)) - - video_frame_num = 1 - mm_data_idx = 0 - for modality_type, start_idx, end_idx in input_type_group: - st_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) - if modality_type == "image": - t, h, w = ( - image_grid_thw[mm_data_idx][0], - image_grid_thw[mm_data_idx][1], - image_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_conv_size, - w // spatial_conv_size, - ) - - t_index = ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx - ) - mm_data_idx += 1 - - elif modality_type == "video": - t, h, w = ( - video_grid_thw[mm_data_idx][0], - video_grid_thw[mm_data_idx][1], - video_grid_thw[mm_data_idx][2], - ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t // temporal_conv_size, - h // spatial_conv_size, - w // spatial_conv_size, - ) - - 
for t_idx in range(llm_grid_t): - t_index = ( - torch.tensor(t_idx) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - ) - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(1, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(1, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx - ) - - mm_data_idx += 1 - video_frame_num += 1 - - else: - text_len = end_idx - start_idx - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - video_frame_num = 1 - - else: - text_len = len(input_tokens) - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:seq_len] - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - return llm_positions, mrope_position_delta - - @classmethod - def _keye_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: - video_grid_thw = video_grid_thw[0] - """Get mrope input positions and delta value (Keye series).""" - - def split_thw(grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]: - """ - Split grid_thw along the t dimension. - - Args: - grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. - - Returns: - List of [1, h, w] rows, repeated t times for each original row. - """ - - if isinstance(grid_thw, list): - grid_thw = torch.tensor(grid_thw, dtype=torch.long) - - if grid_thw.numel() == 0: - return [] - - t, hw = grid_thw[:, 0], grid_thw[:, 1:] - ones = torch.ones_like(hw[:, :1]) # [N,1] - out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) - return out.tolist() - - video_grid_thw = split_thw(video_grid_thw) - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - - image_nums = len(image_grid_thw) - frame_nums = len(video_grid_thw) - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_frames = image_nums, frame_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + frame_nums): - if remain_images > 0: - try: - ed_image = input_tokens.index(image_token_id, st) - except ValueError: - ed_image = len(input_tokens) + 1 - else: - ed_image = len(input_tokens) + 1 - if remain_frames > 0: - try: - ed_video = input_tokens.index(video_token_id, st) - except ValueError: - ed_video = len(input_tokens) + 1 - else: - ed_video = len(input_tokens) + 1 - - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_index += 1 - remain_frames -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, 
-1).expand(3, -1) + st_idx - ) - - t_index = ( - ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - ) - .long() - .flatten() - ) - - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @classmethod - def _vl_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: list[float], - context_len: int = 0, - seq_len: Optional[int] = None, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" - - image_token_id = hf_config.image_token_id - video_token_id = hf_config.video_token_id - vision_start_token_id = hf_config.vision_start_token_id - spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) - - input_tokens_tensor = torch.tensor(input_tokens) - vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id - ).squeeze(1) - vision_tokens = input_tokens_tensor[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - llm_pos_ids_list: list = [] - - st = 0 - remain_images, remain_videos = image_nums, video_nums - - image_index, video_index = 0, 0 - for _ in range(image_nums + video_nums): - video_second_per_grid_t = 0.0 - if remain_images > 0: - try: - ed_image = input_tokens.index(image_token_id, st) - except ValueError: - ed_image = len(input_tokens) + 1 - else: - ed_image = len(input_tokens) + 1 - if remain_videos > 0: - try: - ed_video = input_tokens.index(video_token_id, st) - except ValueError: - ed_video = len(input_tokens) + 1 - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - video_second_per_grid_t = 1.0 - if second_per_grid_ts: - video_second_per_grid_t = second_per_grid_ts[video_index] - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t, - h // spatial_merge_size, - w // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - t_index = ( - ( - torch.arange(llm_grid_t) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - * 
video_second_per_grid_t - * tokens_per_second - ) - .long() - .flatten() - ) - - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(llm_grid_t, -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(llm_grid_t, llm_grid_h, -1) - .flatten() - ) - llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx - ) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @classmethod - def _omni_get_input_positions_tensor( - cls, - input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: Union[list[list[int]], torch.Tensor], - video_grid_thw: Union[list[list[int]], torch.Tensor], - second_per_grid_ts: Optional[list[float]] = None, - context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, - use_audio_in_video: bool = False, - ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value (Qwen2.5-Omni version). - - Differences from MRotaryEmbedding: - 1. Add audio support (and related `audio_feature_lengths`). - 2. Add `use_audio_in_video` option to read audio from video inputs. - In this case, audio and vision position ids will be split into - chunks and interleaved. - - Example: - - (V_i are vision position ids, A_i are audio position ids) - - |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... - |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... - """ - - # TODO(fyabc): refactor and share more code with - # _vl_get_input_positions_tensor. 
- - thinker_config = hf_config.thinker_config - audio_token_id = thinker_config.audio_token_index - image_token_id = thinker_config.image_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - vision_start_token_id = thinker_config.vision_start_token_id - vision_end_token_id = thinker_config.vision_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr( - thinker_config.vision_config, "tokens_per_second", 25 - ) - - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - - src_item = input_tokens - audio_seqlens = audio_feature_lengths - if not second_per_grid_ts: - second_per_grid_ts = [1] * video_grid_thw.shape[0] - audio_idx = 0 - video_idx = 0 - image_idx = 0 - new_src_item: list[int] = [] - llm_pos_ids_list: list[torch.Tensor] = [] - - idx = 0 - while idx < len(src_item): - new_src_item_len = len(new_src_item) - start_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) - if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: - if use_audio_in_video and idx > 0: - if ( - src_item[idx] == vision_end_token_id - and src_item[idx - 1] == audio_end_token_id - ): - # processing the <|audio_eos|> before <|vision_eos|> - start_idx -= 1 - elif ( - src_item[idx] == audio_start_token_id - and src_item[idx - 1] == vision_start_token_id - ): - # processing the <|audio_bos|> after <|vision_eos|> - start_idx -= 1 - new_src_item.append(src_item[idx]) - llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) - llm_pos_ids_list.append(llm_pos_ids) - elif src_item[idx] == audio_token_id: - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 - new_src_item.extend([audio_token_id] * place_num) - llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx - llm_pos_ids_list.append(llm_pos_ids) - audio_idx += 1 - elif src_item[idx] == image_token_id: - grid_t = image_grid_thw[image_idx][0] - grid_hs = image_grid_thw[:, 1] - grid_ws = image_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = image_grid_thw[image_idx].prod() // ( - spatial_merge_size**2 - ) - new_src_item.extend([image_token_id] * vision_seqlen) - image_idx += 1 - elif src_item[idx] == video_token_id and not use_audio_in_video: - grid_t = video_grid_thw[video_idx][0] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_index = ( - torch.arange(grid_t) - * second_per_grid_ts[video_idx] - * tokens_per_second - ).long() - llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - llm_pos_ids_list.append(llm_pos_ids) - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2 - ) - new_src_item.extend([video_token_id] * vision_seqlen) - video_idx += 1 - else: - # read audio from video - assert audio_seqlens is not None - audio_seqlen = audio_seqlens[audio_idx] - vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2 - ) - 
grid_t = video_grid_thw[video_idx][0] - grid_h = video_grid_thw[video_idx][1] - grid_w = video_grid_thw[video_idx][2] - grid_hs = video_grid_thw[:, 1] - grid_ws = video_grid_thw[:, 2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = ( - torch.arange(grid_t) - * second_per_grid_ts[video_idx] - * tokens_per_second - ).long() - t_index_split_chunk = cls._split_list_into_ranges( - t_index, t_ntoken_per_chunk - ) - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 - pure_audio_len = place_num - 2 - added_audio_len = 0 - audio_llm_pos_ids_list: list[torch.Tensor] = [] - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = ( - len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) - ) - new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) - vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( - start_idx, - video_idx, - spatial_merge_size, - t_chunk, - grid_hs, - grid_ws, - ).split(1, dim=1) - llm_pos_ids_list.extend(vision_llm_pos_ids_list) - new_src_item.extend( - min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) - * [audio_token_id] - ) - audio_start_idx = ( - start_idx - if len(audio_llm_pos_ids_list) == 0 - else audio_llm_pos_ids_list[-1][0].item() + 1 - ) - if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: - audio_llm_pos_ids_list = ( - torch.arange( - min( - t_ntoken_per_chunk, pure_audio_len - added_audio_len - ) - ).expand(3, -1) - + audio_start_idx - ).split(1, dim=1) - else: - audio_llm_pos_ids_list = [] - added_audio_len += min( - t_ntoken_per_chunk, pure_audio_len - added_audio_len - ) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - if added_audio_len < pure_audio_len: - new_src_item.extend( - (pure_audio_len - added_audio_len) * [audio_token_id] - ) - audio_llm_pos_ids_list = ( - torch.arange(pure_audio_len - added_audio_len).expand(3, -1) - + llm_pos_ids_list[-1].max() - + 1 - ).split(1, dim=1) - llm_pos_ids_list.extend(audio_llm_pos_ids_list) - audio_idx += 1 - video_idx += 1 - # move to the next token - idx += len(new_src_item) - new_src_item_len - - llm_positions = torch.cat(llm_pos_ids_list, dim=1) - mrope_position_delta = ( - torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) - ) - llm_positions = llm_positions[:, context_len:seq_len] - - return llm_positions, mrope_position_delta - - @staticmethod - def _get_llm_pos_ids_for_vision( - start_idx: int, - vision_idx: int, - spatial_merge_size: int, - t_index: list[int], - grid_hs: torch.Tensor, - grid_ws: torch.Tensor, - ) -> torch.Tensor: - llm_pos_ids_list = [] - llm_grid_h = grid_hs[vision_idx] // spatial_merge_size - llm_grid_w = grid_ws[vision_idx] // spatial_merge_size - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(len(t_index), -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(len(t_index), llm_grid_h, -1) - .flatten() - ) - t_index_tensor = ( - torch.Tensor(t_index) - .to(llm_grid_h.device) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .long() - .flatten() - ) - _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) - llm_pos_ids_list.append(_llm_pos_ids + start_idx) - llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) - return llm_pos_ids - - @staticmethod - def _split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: - ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)] - for num in lst: - index = num // interval - ranges[index].append(num) - return ranges - @staticmethod def 
get_next_input_positions( mrope_position_delta: int, @@ -1420,56 +404,3 @@ def get_next_input_positions_tensor( dtype=out.dtype, ) out[:, out_offset : out_offset + num_new_tokens] = values - - @classmethod - def omni_get_updates_use_audio_in_video( - cls, - thinker_config: PretrainedConfig, - audio_len: int, - video_grid_thw: Union[list[int], torch.Tensor], - video_second_per_grid_t: float, - ) -> list[int]: - """Get video prompt updates when `use_audio_in_video` is True. - - In this case, audio and vision update ids will be split into - chunks and interleaved (details in `_omni_get_input_positions_tensor`). - - <|video_bos|><|VIDEO|><|video_eos|> => - <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|> - """ - - audio_token_id = thinker_config.audio_token_index - video_token_id = thinker_config.video_token_index - audio_start_token_id = thinker_config.audio_start_token_id - audio_end_token_id = thinker_config.audio_end_token_id - seconds_per_chunk = thinker_config.seconds_per_chunk - spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr( - thinker_config.vision_config, "tokens_per_second", 25 - ) - - grid_t = video_grid_thw[0] - grid_h = video_grid_thw[1] - grid_w = video_grid_thw[2] - t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = ( - torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second - ).long() - t_index_split_chunk = cls._split_list_into_ranges(t_index, t_ntoken_per_chunk) - - updates = [audio_start_token_id] - added_audio_len = 0 - for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = ( - len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) - ) - updates.extend([video_token_id] * vision_ntoken_per_chunk) - - audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len) - updates.extend(audio_chunk_size * [audio_token_id]) - added_audio_len += audio_chunk_size - if added_audio_len < audio_len: - updates.extend((audio_len - added_audio_len) * [audio_token_id]) - updates.extend([audio_end_token_id]) - - return updates diff --git a/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py index 560fb100413d..031a12fceba6 100644 --- a/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -21,7 +20,7 @@ def __init__( is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, - mixed_b: Optional[float] = None, + mixed_b: float | None = None, ) -> None: self.scaling_factor = scaling_factor self.mixed_b = mixed_b diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py index 02ad142d676b..2a42e3bd00ec 100644 --- a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -from typing import Optional import torch import torch.nn as nn @@ -26,8 +25,8 @@ def __init__( dtype: torch.dtype, short_factor: list[float], long_factor: list[float], - short_mscale: Optional[float] = None, - long_mscale: Optional[float] = None, 
+ short_mscale: float | None = None, + long_mscale: float | None = None, ): super().__init__() @@ -106,9 +105,9 @@ def forward( self, positions: torch.Tensor, query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert key is not None query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py index 223350d43267..a01d14f7b3a1 100644 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py @@ -5,7 +5,7 @@ import vllm.envs as envs from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def is_rocm_triton_rotary_embedding_enabled() -> bool: diff --git a/vllm/model_executor/layers/shared_fused_moe/__init__.py b/vllm/model_executor/layers/shared_fused_moe/__init__.py deleted file mode 100644 index b047e9cad04a..000000000000 --- a/vllm/model_executor/layers/shared_fused_moe/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import SharedFusedMoE - -__all__ = ["SharedFusedMoE"] diff --git a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py deleted file mode 100644 index a8b09a5c3cdb..000000000000 --- a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional - -import torch - -from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.model_executor.layers.fused_moe.layer import FusedMoE - - -# TODO(bnell): Add shared + fused combo function? e.g. + -class SharedFusedMoE(FusedMoE): - """ - A FusedMoE operation that also computes the results of shared experts. - If an all2all communicator is being used the shared expert computation - can be interleaved with the fused all2all dispatch communication step. - """ - - def __init__( - self, - shared_experts: torch.nn.Module, - use_overlapped: bool = True, - **kwargs, - ): - super().__init__(**kwargs) - self._shared_experts = shared_experts - self.use_overlapped = use_overlapped - - @property - def shared_experts(self) -> Optional[torch.nn.Module]: - return self._shared_experts if self.use_overlapped else None - - def forward( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - if not self.use_overlapped: - shared_out = self._shared_experts(hidden_states) - - # Reduce outputs if necessary, since the MLP should - # have been created with reduce_results=False. 
- if ( - self.reduce_results - and self.tp_size > 1 - and self.must_reduce_shared_expert_outputs() - ): - shared_out = tensor_model_parallel_all_reduce(shared_out) - - fused_out = super().forward( - hidden_states=hidden_states, - router_logits=router_logits, - ) - else: - shared_out, fused_out = super().forward( - hidden_states=hidden_states, - router_logits=router_logits, - ) - return shared_out, fused_out diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index e522cc450d6b..e6b6a70afd97 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -2,14 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility methods for model layers.""" -from typing import Callable, Optional +from collections.abc import Callable import torch from vllm import _custom_ops as ops from vllm import envs from vllm.platforms import CpuArchEnum, current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def shuffle_weight(w: torch.Tensor) -> torch.Tensor: @@ -95,13 +95,13 @@ def default_unquantized_gemm( layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ): return torch.nn.functional.linear(x, weight, bias) def rocm_unquantized_gemm_impl( - x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: from vllm.platforms.rocm import on_gfx9 @@ -131,7 +131,7 @@ def rocm_unquantized_gemm_impl( def rocm_unquantized_gemm_impl_fake( - x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: return x.new_empty((*x.shape[:-1], weight.shape[0])) @@ -140,7 +140,7 @@ def rocm_unquantized_gemm( layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias) @@ -178,9 +178,9 @@ def dispatch_cpu_unquantized_gemm( ) if remove_weight: layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) - elif ops._supports_onednn and ( - current_platform.get_cpu_architecture() == CpuArchEnum.X86 - or ops.is_onednn_acl_supported() + elif ( + ops._supports_onednn + and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC ): origin_weight = layer.weight if remove_weight: @@ -197,7 +197,7 @@ def cpu_unquantized_gemm( layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ): return layer.cpu_linear(x, weight, bias) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b7253c7f0e52..1abc3ad88455 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -3,7 +3,6 @@ from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional import torch import torch.nn.functional as F @@ -65,7 +64,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) @@ -226,10 +225,10 @@ def 
__init__( self, num_embeddings: int, embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, + params_dtype: torch.dtype | None = None, + org_num_embeddings: int | None = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -347,7 +346,7 @@ def _get_indices( added_vocab_end_index, ) - def get_sharded_to_full_mapping(self) -> Optional[list[int]]: + def get_sharded_to_full_mapping(self) -> list[int] | None: """Get a mapping that can be used to reindex the gathered logits for sampling. @@ -515,10 +514,10 @@ def __init__( num_embeddings: int, embedding_dim: int, bias: bool = False, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, + params_dtype: torch.dtype | None = None, + org_num_embeddings: int | None = None, padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__( diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index df0d059594a7..301f2d00bf40 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Literal, Optional +from typing import Literal from torch import nn @@ -122,7 +122,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: def get_model( - *, vllm_config: VllmConfig, model_config: Optional[ModelConfig] = None + *, vllm_config: VllmConfig, model_config: ModelConfig | None = None ) -> nn.Module: loader = get_model_loader(vllm_config.load_config) if model_config is None: diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 6106a1ab8a85..94dfa478245d 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -11,8 +11,8 @@ from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 8c1ff0300b24..97c7a20bc4d5 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -6,8 +6,8 @@ import itertools import math import os -from collections.abc import Generator -from typing import Any, Callable, Optional +from collections.abc import Callable, Generator +from typing import Any import numpy as np import torch @@ -32,7 +32,7 @@ RowParallelLinear, ) from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype +from vllm.model_executor.model_loader.utils import ParamMapping from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, @@ -48,6 +48,7 @@ set_weight_attrs, ) from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) @@ -88,7 
+89,7 @@ def _get_weight_files( self, model_name_or_path: str, allowed_patterns: list[str], - revision: Optional[str] = None, + revision: str | None = None, ) -> tuple[str, list[str], str]: """Retrieve weight files. Download the files if necessary. @@ -122,7 +123,7 @@ def _get_weight_files( raise RuntimeError(f"No model weights found in: `{model_name_or_path}`") def _prepare_weights( - self, model_name_or_path: str, revision: Optional[str] + self, model_name_or_path: str, revision: str | None ) -> tuple[list[str], bool]: """Prepare weight files for the model.""" @@ -196,7 +197,7 @@ def _maybe_pool_model(module_name: str): def _get_quantized_weights_iterator( self, model_name_or_path: str, - revision: Optional[str], + revision: str | None, ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, as well as the quantization state dictionary.""" @@ -542,8 +543,7 @@ def _verify_model_compatibility( ) quant_config = getattr(model_config.hf_config, "quantization_config", None) - if quant_config is not None: - quant_method = quant_config.get("quant_method") + if quant_config and (quant_method := quant_config.get("quant_method")): if quant_method == "bitsandbytes": self.pre_quant = True else: @@ -558,7 +558,7 @@ def _verify_model_compatibility( "Prequant BitsAndBytes models with tensor parallelism is not " "supported. Please try with pipeline parallelism." ) - if self.pre_quant: + if quant_config and self.pre_quant: self.load_8bit = quant_config.get("load_in_8bit", False) def _initialize_loader_state( diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 00944989a002..c06ac550a94a 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -5,7 +5,7 @@ import os import time from collections.abc import Generator, Iterable -from typing import Optional, cast +from typing import cast import torch from torch import nn @@ -47,7 +47,7 @@ class Source: model_or_path: str """The model ID or path.""" - revision: Optional[str] + revision: str | None """The optional model revision.""" prefix: str = "" @@ -56,7 +56,7 @@ class Source: fall_back_to_pt: bool = True """Whether .pt weights can be used.""" - allow_patterns_overrides: Optional[list[str]] = None + allow_patterns_overrides: list[str] | None = None """If defined, weights will load exclusively using these patterns.""" counter_before_loading_weights: float = 0.0 @@ -79,9 +79,9 @@ def __init__(self, load_config: LoadConfig): def _prepare_weights( self, model_name_or_path: str, - revision: Optional[str], + revision: str | None, fall_back_to_pt: bool, - allow_patterns_overrides: Optional[list[str]], + allow_patterns_overrides: list[str] | None, ) -> tuple[str, list[str], bool]: """Prepare weights for the model. @@ -311,9 +311,10 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: loaded_weights = load_weights_and_online_quantize(self, model, model_config) self.counter_after_loading_weights = time.perf_counter() - logger.info( + logger.info_once( "Loading weights took %.2f seconds", self.counter_after_loading_weights - self.counter_before_loading_weights, + scope="local", ) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. 
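The reworked guard in the bitsandbytes loader above only binds `quant_method` when a quantization config is actually present, and `load_in_8bit` is now read under the same condition. A rough standalone sketch of that control flow with hypothetical names (`detect_bnb_prequant` is not part of vLLM, and the error branch is illustrative rather than the loader's actual message):

```python
from typing import Any


def detect_bnb_prequant(quant_config: dict[str, Any] | None) -> tuple[bool, bool]:
    """Return (pre_quant, load_8bit) derived from an optional HF quantization config."""
    pre_quant = False
    load_8bit = False
    # The assignment expression binds quant_method only when the config exists
    # and carries a non-empty "quant_method" entry.
    if quant_config and (quant_method := quant_config.get("quant_method")):
        if quant_method == "bitsandbytes":
            pre_quant = True
        else:
            raise ValueError(f"Unsupported quant_method: {quant_method}")
    if quant_config and pre_quant:
        load_8bit = quant_config.get("load_in_8bit", False)
    return pre_quant, load_8bit


print(detect_bnb_prequant(None))                              # (False, False)
print(detect_bnb_prequant({"quant_method": "bitsandbytes"}))  # (True, False)
print(detect_bnb_prequant({"quant_method": "bitsandbytes",
                           "load_in_8bit": True}))            # (True, True)
```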
diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 93dc754a571c..7db1fc167c4f 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -15,13 +15,13 @@ from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) from vllm.model_executor.model_loader.weight_utils import ( get_gguf_extra_tensor_names, get_gguf_weight_type_map, gguf_quant_weights_iterator, ) +from vllm.utils.torch_utils import set_default_torch_dtype class GGUFModelLoader(BaseModelLoader): @@ -72,6 +72,10 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): # hack: ggufs have a different name than transformers if model_type == "cohere": model_type = "command-r" + if model_type == "gemma3_text": + # Gemma3 models use "gemma3_text" in HuggingFace but + # "gemma3" in GGUF architecture naming + model_type = "gemma3" if model_type in ("deepseek_v3", "deepseek_v2"): model_type = "deepseek2" # GGUF layer map assumes that we will have a merged expert weights diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index 50a92edd1162..079e3168647b 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -3,7 +3,6 @@ # ruff: noqa: SIM117 import os from collections.abc import Generator -from typing import Optional import torch from torch import nn @@ -51,7 +50,7 @@ def __init__(self, load_config: LoadConfig): os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url def _prepare_weights( - self, model_name_or_path: str, revision: Optional[str] + self, model_name_or_path: str, revision: str | None ) -> list[str]: """Prepare weights for the model. 
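The gguf_loader change above extends the HuggingFace-to-GGUF architecture name shim with a `gemma3_text` -> `gemma3` case, next to the existing `cohere` -> `command-r` and `deepseek_v2`/`deepseek_v3` -> `deepseek2` remaps. A sketch of how that remapping behaves, factored into a hypothetical standalone helper (the real logic lives inline in `GGUFModelLoader._get_gguf_weights_map`):

```python
# Hypothetical helper mirroring the inline remapping in
# GGUFModelLoader._get_gguf_weights_map: HF `model_type` values whose GGUF
# architecture string differs are translated before building the layer map.
_HF_TO_GGUF_ARCH = {
    "cohere": "command-r",
    "gemma3_text": "gemma3",   # Gemma3 text models: "gemma3_text" in HF, "gemma3" in GGUF
    "deepseek_v2": "deepseek2",
    "deepseek_v3": "deepseek2",
}


def to_gguf_architecture(hf_model_type: str) -> str:
    """Return the GGUF architecture name for an HF model_type (identity if no shim)."""
    return _HF_TO_GGUF_ARCH.get(hf_model_type, hf_model_type)


assert to_gguf_architecture("gemma3_text") == "gemma3"
assert to_gguf_architecture("llama") == "llama"
```

Without the new entry, Gemma3 GGUF checkpoints would be looked up under the HF name and fail to match the GGUF layer map.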
diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index e65eb78819e2..d94dbd9f06e0 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -5,7 +5,7 @@ import glob import os from collections.abc import Generator -from typing import Any, Optional +from typing import Any import torch from torch import nn @@ -89,7 +89,7 @@ def get_end_ptr(tensor: torch.Tensor) -> int: result[k] = t return result - def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]): + def _prepare_weights(self, model_name_or_path: str, revision: str | None): if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): return model_name_or_path else: @@ -171,8 +171,8 @@ def iterate_over_files( def save_model( model: torch.nn.Module, path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, + pattern: str | None = None, + max_size: int | None = None, ) -> None: from safetensors.torch import save_file diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 9d58278f996b..2890a2c6d702 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -12,7 +12,7 @@ import time from collections.abc import Generator, MutableMapping from dataclasses import asdict, dataclass, field, fields -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Optional import regex as re import torch @@ -26,7 +26,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform -from vllm.utils import FlexibleArgumentParser, PlaceholderModule +from vllm.utils import FlexibleArgumentParser +from vllm.utils.import_utils import PlaceholderModule if TYPE_CHECKING: from vllm.engine.arg_utils import EngineArgs @@ -67,7 +68,7 @@ logger = init_logger(__name__) -def is_valid_deserialization_uri(uri: Optional[str]) -> bool: +def is_valid_deserialization_uri(uri: str | None) -> bool: if uri: scheme = uri.lower().split("://")[0] return scheme in {"s3", "http", "https"} or os.path.exists(uri) @@ -156,25 +157,23 @@ def wrapper(*args, **kwargs): @dataclass class TensorizerConfig(MutableMapping): - tensorizer_uri: Optional[str] = None - tensorizer_dir: Optional[str] = None - vllm_tensorized: Optional[bool] = None - verify_hash: Optional[bool] = None - num_readers: Optional[int] = None - encryption_keyfile: Optional[str] = None - s3_access_key_id: Optional[str] = None - s3_secret_access_key: Optional[str] = None - s3_endpoint: Optional[str] = None - lora_dir: Optional[str] = None - stream_kwargs: Optional[dict[str, Any]] = None - serialization_kwargs: Optional[dict[str, Any]] = None - deserialization_kwargs: Optional[dict[str, Any]] = None - _extra_serialization_attrs: Optional[dict[str, Any]] = field( - init=False, default=None - ) - model_class: Optional[type[torch.nn.Module]] = field(init=False, default=None) - hf_config: Optional[PretrainedConfig] = field(init=False, default=None) - dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None) + tensorizer_uri: str | None = None + tensorizer_dir: str | None = None + vllm_tensorized: bool | None = None + verify_hash: bool | None = None + num_readers: int | None = None + encryption_keyfile: str | None = None + s3_access_key_id: 
str | None = None + s3_secret_access_key: str | None = None + s3_endpoint: str | None = None + lora_dir: str | None = None + stream_kwargs: dict[str, Any] | None = None + serialization_kwargs: dict[str, Any] | None = None + deserialization_kwargs: dict[str, Any] | None = None + _extra_serialization_attrs: dict[str, Any] | None = field(init=False, default=None) + model_class: type[torch.nn.Module] | None = field(init=False, default=None) + hf_config: PretrainedConfig | None = field(init=False, default=None) + dtype: str | torch.dtype | None = field(init=False, default=None) _is_sharded: bool = field(init=False, default=False) _fields: ClassVar[tuple[str, ...]] _keys: ClassVar[frozenset[str]] @@ -362,9 +361,9 @@ def __delitem__(self, key, /): @dataclass class TensorizerArgs: - tensorizer_uri: Optional[str] = None - tensorizer_dir: Optional[str] = None - encryption_keyfile: Optional[str] = None + tensorizer_uri: str | None = None + tensorizer_dir: str | None = None + encryption_keyfile: str | None = None def __init__(self, tensorizer_config: TensorizerConfig): for k, v in tensorizer_config.items(): @@ -520,7 +519,7 @@ def init_tensorizer_model( ) -> nn.Module: assert tensorizer_config.hf_config is not None model_args = tensorizer_config.hf_config - model_args.torch_dtype = tensorizer_config.dtype + model_args.dtype = tensorizer_config.dtype assert tensorizer_config.model_class is not None # TODO: Do we need to consider old-style model class? with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True): @@ -621,7 +620,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: def serialize_extra_artifacts( - tensorizer_args: TensorizerArgs, served_model_name: Union[str, list[str], None] + tensorizer_args: TensorizerArgs, served_model_name: str | list[str] | None ) -> None: if not isinstance(served_model_name, str): raise ValueError( diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 5585a74f8926..2b3704cfebba 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -3,7 +3,6 @@ # ruff: noqa: SIM117 import copy from collections.abc import Generator -from typing import Union import torch from torch import nn @@ -23,8 +22,8 @@ from vllm.model_executor.model_loader.utils import ( get_model_architecture, initialize_model, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) @@ -140,7 +139,7 @@ def load_model( @staticmethod def save_model( model: torch.nn.Module, - tensorizer_config: Union[TensorizerConfig, dict], + tensorizer_config: TensorizerConfig | dict, model_config: ModelConfig, ) -> None: if isinstance(tensorizer_config, dict): diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index fc97003de8e3..fc142f1f07fa 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from typing import Optional import torch import torch.nn as nn @@ -15,8 +14,8 @@ from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) @@ -30,7 +29,7 @@ def load_model( self, vllm_config: 
VllmConfig, model_config: ModelConfig, - mesh: Optional[xs.Mesh] = None, + mesh: xs.Mesh | None = None, ) -> nn.Module: # Initialize model and load weights on CPU. Then, during SPMD partition, # weights are sharded and transferred to TPUs. @@ -90,7 +89,7 @@ def load_model( ) return model - def _check_model_is_loaded(self, mesh: Optional[xs.Mesh], model: nn.Module) -> None: + def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None: """ Ensure the model is properly loaded. 1. All model parameters and buffers are on XLA device. diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index ba8d53c0ba14..ba708a098c0d 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -2,21 +2,19 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for selecting and loading models.""" -import contextlib import inspect import warnings from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Optional import torch from torch import nn from typing_extensions import assert_never from vllm.attention import Attention +from vllm.attention.layer import MLAAttention from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.logger import init_logger -from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -28,26 +26,17 @@ try_create_mm_pooling_model_cls, ) from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available logger = init_logger(__name__) -@contextlib.contextmanager -def set_default_torch_dtype(dtype: torch.dtype): - """Sets the default torch dtype to the given dtype.""" - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(old_dtype) - - def initialize_model( vllm_config: VllmConfig, *, prefix: str = "", - model_class: Optional[type[nn.Module]] = None, - model_config: Optional[ModelConfig] = None, + model_class: type[nn.Module] | None = None, + model_config: ModelConfig | None = None, ) -> nn.Module: """Initialize a model with the given configurations.""" if model_config is None: @@ -107,11 +96,6 @@ def process_weights_after_loading( maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config) for _, module in model.named_modules(): - if isinstance(module, QKVCrossParallelLinear): - # NOTE(Isotr0py): special case for cross QKV layer because - # q and kv proj aren't registered as submodules intentionally - module.process_weights_after_loading() - continue quant_method = getattr(module, "quant_method", None) if isinstance(quant_method, QuantizeMethodBase): # When quant methods need to process weights after loading @@ -122,11 +106,10 @@ def process_weights_after_loading( with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) - # Currently only used by MLA. - # NOTE: This intentionally happens after other modules so we can easily - # decompress the weights for MLA. + # Initialize post-load attention weights for both Attention and MLA. + # NOTE: Happens after other modules so we can easily decompress weights. 
for _, module in model.named_modules(): - if isinstance(module, Attention) and hasattr( + if isinstance(module, (Attention, MLAAttention)) and hasattr( module, "process_weights_after_loading" ): # TODO(lucas): see if there is a way to unify the signatures @@ -274,7 +257,7 @@ def __post_init__(self): index, ) - def get_sub_modules(self, module_name: str) -> Optional[tuple[str, list[str]]]: + def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None: for key, value in self.packed_mapping.items(): if module_name.endswith(key): return key, value diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 5f83482bec3a..5a9faefa4d89 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -11,10 +11,10 @@ import tempfile import time from collections import defaultdict -from collections.abc import Generator +from collections.abc import Callable, Generator from contextlib import contextmanager from pathlib import Path -from typing import IO, Any, Callable, Optional, Union +from typing import IO, Any import filelock import huggingface_hub.constants @@ -34,7 +34,7 @@ get_quantization_config, ) from vllm.platforms import current_platform -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule try: from runai_model_streamer import SafetensorsStreamer @@ -85,7 +85,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, disable=True) -def get_lock(model_name_or_path: Union[str, Path], cache_dir: Optional[str] = None): +def get_lock(model_name_or_path: str | Path, cache_dir: str | None = None): lock_dir = cache_dir or temp_dir model_name_or_path = str(model_name_or_path) os.makedirs(os.path.dirname(lock_dir), exist_ok=True) @@ -100,7 +100,7 @@ def get_lock(model_name_or_path: Union[str, Path], cache_dir: Optional[str] = No @contextmanager def atomic_writer( - filepath: Union[str, Path], mode: str = "w", encoding: Optional[str] = None + filepath: str | Path, mode: str = "w", encoding: str | None = None ) -> Generator[IO]: """ Context manager that provides an atomic file writing routine. @@ -143,11 +143,11 @@ def atomic_writer( def maybe_download_from_modelscope( model: str, - revision: Optional[str] = None, - download_dir: Optional[str] = None, - ignore_patterns: Optional[Union[str, list[str]]] = None, - allow_patterns: Optional[Union[list[str], str]] = None, -) -> Optional[str]: + revision: str | None = None, + download_dir: str | None = None, + ignore_patterns: str | list[str] | None = None, + allow_patterns: list[str] | str | None = None, +) -> str | None: """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. Returns the path to the downloaded model, or None if the model is not @@ -370,10 +370,10 @@ def get_sparse_attention_config( def download_weights_from_hf( model_name_or_path: str, - cache_dir: Optional[str], + cache_dir: str | None, allow_patterns: list[str], - revision: Optional[str] = None, - ignore_patterns: Optional[Union[str, list[str]]] = None, + revision: str | None = None, + ignore_patterns: str | list[str] | None = None, ) -> str: """Download model weights from Hugging Face Hub. @@ -416,7 +416,7 @@ def download_weights_from_hf( e, ) - logger.info("Using model weights format %s", allow_patterns) + logger.debug("Using model weights format %s", allow_patterns) # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. 
with get_lock(model_name_or_path, cache_dir): @@ -448,8 +448,8 @@ def download_weights_from_hf( def download_safetensors_index_file_from_hf( model_name_or_path: str, index_file: str, - cache_dir: Optional[str], - revision: Optional[str] = None, + cache_dir: str | None, + revision: str | None = None, ) -> None: """Download hf safetensors index file from Hugging Face Hub. @@ -540,7 +540,7 @@ def enable_tqdm(use_tqdm_on_load: bool): def np_cache_weights_iterator( model_name_or_path: str, - cache_dir: Optional[str], + cache_dir: str | None, hf_folder: str, hf_weights_files: list[str], use_tqdm_on_load: bool, @@ -746,7 +746,7 @@ def fastsafetensors_weights_iterator( def pt_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, - pt_load_map_location: Union[str, dict[str, str]] = "cpu", + pt_load_map_location: str | dict[str, str] = "cpu", ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" for bin_file in tqdm( @@ -765,7 +765,7 @@ def pt_weights_iterator( def multi_thread_pt_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, - pt_load_map_location: Union[str, dict[str, str]] = "cpu", + pt_load_map_location: str | dict[str, str] = "cpu", max_workers: int = 4, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Multi-Thread iterate over the weights in the model bin/pt files.""" @@ -985,7 +985,7 @@ def initialize_dummy_weights( param.uniform_(low, high, generator=generator) -def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: """Remap the name of FP8 k/v_scale parameters. This function handles the remapping of FP8 k/v_scale parameter names. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index b56cb3340048..9f8dd042bf83 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -8,14 +8,12 @@ SupportsMultiModal, SupportsPP, SupportsTranscription, - SupportsV0Only, has_inner_state, supports_lora, supports_mrope, supports_multimodal, supports_pp, supports_transcription, - supports_v0_only, ) from .interfaces_base import ( VllmModelForPooling, @@ -43,6 +41,4 @@ "supports_pp", "SupportsTranscription", "supports_transcription", - "SupportsV0Only", - "supports_v0_only", ] diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index fd8a0b87e43e..7990024c55d0 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -4,7 +4,7 @@ import ast import inspect from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast import torch import torch.nn as nn @@ -13,7 +13,10 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig -from vllm.transformers_utils.config import get_hf_file_bytes, get_hf_file_to_dict +from vllm.transformers_utils.config import ( + get_hf_file_bytes, + try_get_dense_modules, +) from .interfaces_base import VllmModelForPooling, is_pooling_model @@ -32,46 +35,28 @@ ] -def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: +def _load_st_projector(model_config: "ModelConfig") -> nn.Module | None: """Load Sentence-Transformers Dense projection layers.""" - try: - modules = get_hf_file_to_dict( - 
"modules.json", model_config.model, model_config.revision - ) - if not modules: - return None - - if isinstance(modules, dict): - modules = modules.get("modules", []) + dense_modules = try_get_dense_modules( + model_config.model, revision=model_config.revision + ) - dense_modules = [ - m for m in modules if m.get("type") == "sentence_transformers.models.Dense" - ] - if not dense_modules: - return None + if dense_modules is None: + return + try: layers = [] - for module in dense_modules: - folder = module.get("path", "") - - config_path = f"{folder}/config.json" if folder else "config.json" - layer_config = get_hf_file_to_dict( - config_path, model_config.model, model_config.revision - ) - if not layer_config: - continue - + for layer_config in dense_modules: + folder = layer_config["folder"] linear = nn.Linear( - layer_config.get("in_features", 768), - layer_config.get("out_features", 768), + layer_config["in_features"], + layer_config["out_features"], bias=layer_config.get("bias", True), dtype=model_config.head_dtype, ) - if not _load_dense_weights(linear, folder, model_config): continue - layers.append(linear) if act_name := layer_config.get("activation_function"): layers.append(get_act_fn(act_name)) @@ -265,7 +250,7 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), }, ) @@ -294,79 +279,52 @@ def as_seq_cls_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.pooler import ( - ClassifierPooler, DispatchPooler, Pooler, - PoolingMethod, - PoolingType, ) from vllm.model_executor.models.interfaces import SupportsCrossEncoding - from vllm.sequence import IntermediateTensors - from .utils import get_model_hidden_size, maybe_prefix + from .utils import maybe_prefix class ModelForSequenceClassification( _create_pooling_model_cls(cls), SupportsCrossEncoding ): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): - config = vllm_config.model_config.hf_config + text_config = vllm_config.model_config.hf_config.get_text_config() + model_config = vllm_config.model_config quant_config = vllm_config.quant_config - hidden_size = get_model_hidden_size(config) self.score = ReplicatedLinear( - hidden_size, - config.num_labels, + model_config.hidden_size, + text_config.num_labels, bias=False, - params_dtype=torch.float32, + params_dtype=vllm_config.model_config.head_dtype, quant_config=quant_config, + return_bias=False, prefix=maybe_prefix(prefix, "score"), ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - pooling_type_str = pooler_config.pooling_type - assert pooling_type_str is not None - pooling_type = PoolingType[pooling_type_str] - self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), - "classify": ClassifierPooler( - pooling=PoolingMethod.from_pooling_type(pooling_type), - classifier=self._classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.score ), - "score": ClassifierPooler( - pooling=PoolingMethod.from_pooling_type(pooling_type), - classifier=self._classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + "classify": Pooler.for_classify( + pooler_config, classifier=self.score, 
act_fn="classify" + ), + "score": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="score" ), } ) - def _classifier(self, x: torch.Tensor): - x, _ = self.score(x.float()) - return x - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return super().forward( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - tokens = getattr(self.config, "classifier_from_token", None) - method = getattr(self.config, "method", None) + text_config = self.config.get_text_config() + tokens = getattr(text_config, "classifier_from_token", None) + method = getattr(text_config, "method", None) if tokens is None and method is None: return super().load_weights(weights) @@ -399,13 +357,20 @@ def as_reward_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.pooler import DispatchPooler, Pooler + from .interfaces_base import default_pooling_type + + @default_pooling_type("ALL") class ModelForReward(_create_pooling_model_cls(cls)): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, + { + "token_classify": Pooler.for_token_classify( + pooler_config=pooler_config + ) + } ) ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward") @@ -416,9 +381,9 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): class SequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: - config = vllm_config.model_config.hf_config - method = getattr(config, "method", None) - tokens = getattr(config, "classifier_from_token", None) + text_config = vllm_config.model_config.hf_config.get_text_config() + method = getattr(text_config, "method", None) + tokens = getattr(text_config, "classifier_from_token", None) if method is None: return @@ -428,13 +393,13 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: if method == "from_2_way_softmax": assert len(tokens) == 2 - config.num_labels = 1 + text_config.num_labels = 1 else: - config.num_labels = len(tokens) + text_config.num_labels = len(tokens) # `llm as reranker` defaults to not using pad_token - use_pad_token = getattr(config, "use_pad_token", False) - config.use_pad_token = use_pad_token + use_pad_token = getattr(text_config, "use_pad_token", False) + text_config.use_pad_token = use_pad_token def load_weights_using_from_2_way_softmax( @@ -443,24 +408,31 @@ def load_weights_using_from_2_way_softmax( # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader - from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config + quant_config = model.vllm_config.quant_config + text_config = model.config.get_text_config() - tokens = getattr(model.config, "classifier_from_token", []) + tokens = getattr(text_config, "classifier_from_token", []) tokens = cast(list[int], tokens) assert len(tokens) == 2 - if model.config.tie_word_embeddings: - model.lm_head = model.model.embed_tokens - else: - quant_config = 
model.vllm_config.quant_config - model.lm_head = ParallelLMHead( - model.config.vocab_size, model.config.hidden_size, quant_config=quant_config + model.lm_head = ParallelLMHead( + text_config.vocab_size, text_config.hidden_size, quant_config=quant_config + ) + if text_config.tie_word_embeddings: + # embed_tokens is the assumed name for input embeddings. If the model does not + # have this attribute, we fallback to get_input_embeddings(), which is used by + # the Transformers backend. + embed_tokens = ( + model.model.embed_tokens + if hasattr(model.model, "embed_tokens") + else model.model.get_input_embeddings() ) + model.lm_head = model.lm_head.tie_weights(embed_tokens) - loader = AutoWeightsLoader(model) - loaded_weights = loader.load_weights(weights) + # Skip ModelForSequenceClassification in MRO to avoid infinite recursion + loaded_weights = type(model).__mro__[1].load_weights(model, weights) from vllm.transformers_utils.tokenizer import get_tokenizer @@ -490,23 +462,31 @@ def load_weights_using_from_2_way_softmax( def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]): from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader - from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config - tokens = getattr(model.config, "classifier_from_token", []) + quant_config = model.vllm_config.quant_config + text_config = model.config.get_text_config() + + tokens = getattr(text_config, "classifier_from_token", []) tokens = cast(list[int], tokens) assert len(tokens) > 0 - if model.config.tie_word_embeddings: - model.lm_head = model.model.embed_tokens - else: - quant_config = model.vllm_config.quant_config - model.lm_head = ParallelLMHead( - model.config.vocab_size, model.config.hidden_size, quant_config=quant_config + model.lm_head = ParallelLMHead( + text_config.vocab_size, text_config.hidden_size, quant_config=quant_config + ) + if text_config.tie_word_embeddings: + # embed_tokens is the assumed name for input embeddings. If the model does not + # have this attribute, we fallback to get_input_embeddings(), which is used by + # the Transformers backend. 
+ embed_tokens = ( + model.model.embed_tokens + if hasattr(model.model, "embed_tokens") + else model.model.get_input_embeddings() ) + model.lm_head = model.lm_head.tie_weights(embed_tokens) - loader = AutoWeightsLoader(model) - loaded_weights = loader.load_weights(weights) + # Skip ModelForSequenceClassification in MRO to avoid infinite recursion + loaded_weights = type(model).__mro__[1].load_weights(model, weights) from vllm.transformers_utils.tokenizer import get_tokenizer @@ -547,7 +527,7 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]): # - GemmaForCausalLM # - bge-reranker-v2-gemma - config = model.vllm_config.model_config.hf_config - method = getattr(config, "method", None) + text_config = model.vllm_config.model_config.hf_config.get_text_config() + method = getattr(text_config, "method", None) assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported" return SEQ_CLS_LOAD_METHODS[method](model, weights) diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index 2423ad5b0c3a..5872e8196ead 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -4,7 +4,6 @@ # A modified implementation of the AIMv2 Transformer # inserted here also the image tokenizer used by Ovis2 from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -165,7 +164,7 @@ def __init__( config: AIMv2Config, quant_config: QuantizationConfig, *, - require_post_norm: Optional[bool] = None, + require_post_norm: bool | None = None, prefix: str = "", ): super().__init__() @@ -196,7 +195,7 @@ def __init__( config: AIMv2Config, quant_config: QuantizationConfig, *, - require_post_norm: Optional[bool] = None, + require_post_norm: bool | None = None, prefix: str = "", ): super().__init__() diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 743207082721..72e5ddcf1abe 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -26,7 +26,8 @@ """Inference-only Apertus model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Any, Optional, Union +from itertools import islice +from typing import Any import torch from torch import nn @@ -76,7 +77,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, prefix: str = "", reduce_results: bool = True, @@ -119,12 +120,12 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, bias_o_proj: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -224,8 +225,8 @@ def forward( def _init_rotary_emb( self, config: ApertusConfig, - rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig], + rope_scaling: dict[str, Any] | None, + quant_config: QuantizationConfig | None, ) -> None: is_neox_style = True is_gguf = quant_config and quant_config.get_name() == "gguf" @@ -247,8 +248,8 @@ class ApertusDecoderLayer(nn.Module): def __init__( self, config: 
ApertusConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -315,7 +316,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -393,13 +394,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[ - torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] - ]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -412,7 +411,9 @@ def forward( residual = intermediate_tensors["residual"] aux_hidden_states = [] - for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) @@ -583,9 +584,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -594,7 +595,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 634e94b16814..08bf1a6aad75 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -10,7 +10,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -52,7 +52,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[Any] = None, + quant_config: Any | None = None, bias: bool = False, prefix: str = "", reduce_results: bool = True, @@ -98,8 +98,8 @@ class ArceeDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, - cache_config: Optional[Any] = None, - quant_config: Optional[Any] = None, + cache_config: Any | None = None, + quant_config: Any | None = None, prefix: str = "", ) -> None: super().__init__() @@ -165,7 +165,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self-Attention block if residual is None: @@ -247,13 +247,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> 
torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[ - torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] - ]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: # Embedding lookup (on first pipeline rank) if get_pp_group().is_first_rank: hidden_states = ( @@ -415,9 +413,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids=input_ids, positions=positions, @@ -426,7 +424,7 @@ def forward( ) return model_output - def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: # Compute final logits from hidden states (last pipeline rank only) logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 760df1cef82b..e0b6444c9183 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -63,7 +62,7 @@ def __init__( config: ArcticConfig, expert_id: int = -1, is_residual_mlp: bool = False, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ): @@ -107,9 +106,9 @@ class ArcticMoE(nn.Module): def __init__( self, config: ArcticConfig, - tp_size: Optional[int] = None, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + tp_size: int | None = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ): @@ -265,8 +264,8 @@ class ArcticAttention(nn.Module): def __init__( self, config: ArcticConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -342,8 +341,8 @@ class ArcticDecoderLayer(nn.Module): def __init__( self, config: ArcticConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -443,9 +442,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not 
None: hidden_states = inputs_embeds @@ -499,9 +498,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -510,7 +509,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 7db118ca0745..222a42579054 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch import torch.nn as nn @@ -13,7 +13,7 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -71,7 +71,7 @@ class AriaImagePixelInputs(TensorSchema): ] pixel_mask: Annotated[ - Optional[torch.Tensor], + torch.Tensor | None, TensorShape("bn", "h", "w"), ] @@ -82,7 +82,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): def __init__( self, config: Idefics2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__(config, quant_config=quant_config, prefix=prefix) @@ -180,7 +180,7 @@ def __init__(self, config: AriaConfig) -> None: def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: batch_size, num_patches = x.shape[0], x.shape[1] @@ -206,7 +206,7 @@ def forward( return out -class AriaFusedMoE(FusedMoE): +class AriaFusedMoE(SharedFusedMoE): def weight_loader( self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str ) -> None: @@ -250,7 +250,7 @@ class AriaTextMoELayer(nn.Module): def __init__( self, config: AriaTextConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", ) -> None: super().__init__() @@ -260,7 +260,16 @@ def __init__( torch.empty((self.config.moe_num_experts, self.config.hidden_size)) ) + self.shared_experts = LlamaMLP( + config.hidden_size, + config.intermediate_size * config.moe_num_shared_experts, + "silu", + quant_config=quant_config, + bias=config.mlp_bias, + ) + self.experts = AriaFusedMoE( + shared_experts=self.shared_experts, num_experts=config.moe_num_experts, top_k=config.moe_topk, hidden_size=config.hidden_size, @@ -269,13 +278,6 @@ def __init__( reduce_results=True, prefix=f"{prefix}.experts", ) - 
self.shared_experts = LlamaMLP( - config.hidden_size, - config.intermediate_size * config.moe_num_shared_experts, - "silu", - quant_config=quant_config, - bias=config.mlp_bias, - ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ @@ -291,12 +293,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_output = torch.nn.functional.linear(hidden_states, self.router_weight) - hidden_states_copy = hidden_states.clone() - # NOTE: hidden_states will be modified inplace by `FusedMoE` sparse_expert_output = self.experts(hidden_states, router_output) - shared_expert_output = self.shared_experts(hidden_states_copy) - return sparse_expert_output + shared_expert_output + if self.shared_experts is not None: + return sparse_expert_output[0] + sparse_expert_output[1] + else: + return sparse_expert_output class AriaTextDecoderLayer(LlamaDecoderLayer): @@ -413,7 +415,7 @@ def get_vision_config(self): def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(AriaProcessor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens(self) -> int: @@ -434,7 +436,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: vision_config = self.info.get_vision_config() @@ -515,7 +517,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|fim_prefix|><|img|><|fim_suffix|>" @@ -560,7 +562,7 @@ def __init__( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[AriaImagePixelInputs]: + ) -> AriaImagePixelInputs | None: pixel_values = kwargs.pop("pixel_values", None) pixel_mask = kwargs.pop("pixel_mask", None) @@ -575,8 +577,8 @@ def _parse_and_validate_image_input( def _create_patch_attention_mask( self, - pixel_mask: Optional[torch.Tensor], - ) -> Optional[torch.Tensor]: + pixel_mask: torch.Tensor | None, + ) -> torch.Tensor | None: if pixel_mask is None: return None @@ -626,10 +628,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if inputs_embeds is None: multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings( diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 6e93de524e48..839ab5947e09 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch from torch import nn @@ -139,7 +139,7 @@ def 
get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor: def get_image_processor(self, **kwargs: object) -> GotOcr2ImageProcessor: return self.get_hf_processor(**kwargs).image_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_image_size_with_most_features(self) -> ImageSize: @@ -187,7 +187,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) image_size = self.info.get_image_size_with_most_features() @@ -331,7 +331,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -375,7 +375,7 @@ def _image_pixels_to_features( self, vision_tower: SiglipVisionModel, pixel_values: torch.Tensor, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: return vision_tower( pixel_values.to(dtype=vision_tower.dtype), feature_select_strategy=self.config.vision_feature_select_strategy, @@ -395,7 +395,7 @@ def _process_image_input( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[AyaVisionImagePixelInputs]: + ) -> AyaVisionImagePixelInputs | None: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -428,10 +428,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -446,5 +446,5 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index a8f0e5993e2b..ccf32c9ee1ac 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -24,7 +24,6 @@ import math from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -98,7 +97,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -130,8 +129,8 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -213,8 +212,8 @@ def __init__( self, config: PretrainedConfig, position_embedding: str, - cache_config: Optional[CacheConfig] = 
None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -246,7 +245,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -305,9 +304,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -422,9 +421,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -433,7 +432,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 3911ba599069..1549c653482f 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -26,7 +26,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch import torch.nn.functional as F @@ -43,7 +42,7 @@ tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -75,8 +74,8 @@ class BailingAttention(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ): @@ -87,13 +86,12 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() assert self.total_num_heads % tp_size == 0 - assert self.total_kv_heads % tp_size == 0 assert self.total_num_heads >= self.total_kv_heads self.num_heads = self.total_num_heads // tp_size self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) self.q_size_per_rank = self.head_dim * self.num_heads - self.num_kv_heads = self.total_kv_heads // tp_size + self.num_kv_heads = max(1, self.total_kv_heads // tp_size) self.kv_size_per_rank = self.num_kv_heads * self.head_dim self.scale = self.head_dim**-0.5 self.use_qk_norm = getattr(config, "use_qk_norm", False) @@ -184,8 +182,8 @@ def __init__( self, intermediate_size: int, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: 
Optional[bool] = True, + quant_config: QuantizationConfig | None = None, + reduce_results: bool | None = True, prefix: str = "", ) -> None: super().__init__() @@ -218,8 +216,8 @@ def __init__( self, intermediate_size: int, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: Optional[bool] = True, + quant_config: QuantizationConfig | None = None, + reduce_results: bool | None = True, prefix: str = "", ): super().__init__() @@ -276,22 +274,6 @@ def __init__( # default value for scoring_func self.score_function = "softmax" - self.experts = FusedMoE( - num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=self.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=self.norm_expert_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - scoring_func=self.score_function, - e_score_correction_bias=self.gate.expert_bias, - num_expert_group=self.n_group, - topk_group=self.topk_group, - use_grouped_topk=self.use_grouped_topk, - ) - if self.num_shared_experts > 0: if hasattr(config, "moe_shared_expert_intermediate_size"): intermediate_size = config.moe_shared_expert_intermediate_size @@ -308,11 +290,27 @@ def __init__( else: self.shared_experts = None + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + scoring_func=self.score_function, + e_score_correction_bias=self.gate.expert_bias, + num_expert_group=self.n_group, + topk_group=self.topk_group, + use_grouped_topk=self.use_grouped_topk, + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_size) - if self.shared_experts: - shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states.to(self.router_dtype)) router_logits = router_logits.to(hidden_states.dtype) @@ -321,9 +319,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states=hidden_states, router_logits=router_logits ) + if self.shared_experts is not None: + shared_output, final_hidden_states = final_hidden_states + else: + shared_output = None + final_hidden_states *= self.routed_scaling_factor - if self.shared_experts: + if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: @@ -335,8 +338,8 @@ class BailingMoeBlock(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -365,7 +368,7 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: if residual is None: residual = hidden_states @@ -442,9 +445,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = 
None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -475,7 +478,7 @@ def forward( return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -614,9 +617,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -625,7 +628,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 42c1c7be1a75..1a06f0659235 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -4,7 +4,6 @@ # Added by the IBM Team, 2024 from collections.abc import Iterable -from typing import Optional import torch from torch import nn @@ -52,7 +51,7 @@ class BambaMLP(nn.Module): def __init__( self, config: BambaConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, ) -> None: super().__init__() @@ -87,9 +86,9 @@ def __init__( self, config: BambaConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -119,7 +118,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -141,9 +140,9 @@ def __init__( self, config: BambaConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -235,7 +234,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -314,8 +313,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -497,8 +496,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: 
IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): hidden_states = self.model( @@ -510,7 +509,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/bee.py b/vllm/model_executor/models/bee.py new file mode 100644 index 000000000000..4f0342df404b --- /dev/null +++ b/vllm/model_executor/models/bee.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Mapping + +import torch +import torch.nn as nn +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict + +from .llava_next import ( + LlavaDummyInputsBuilder, + LlavaNextMultiModalProcessor, + LlavaNextProcessingInfo, +) +from .llava_onevision import LlavaOnevisionForConditionalGeneration +from .utils import WeightsMapper + + +class BeeProcessingInfo(LlavaNextProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(**kwargs) + + def _get_num_unpadded_features( + self, + *, + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, + ) -> tuple[int, int]: + """Override to use correct max_num_patches from vision_aspect_ratio.""" + import math + + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if aspect_ratio > current_aspect_ratio: + new_height = int( + round(original_height * (current_width / original_width), 7) + ) + padding = (current_height - new_height) // 2 + current_height = current_height - (2 * padding) + else: + new_width = int( + round(original_width * (current_height / original_height), 7) + ) + padding = (current_width - new_width) // 2 + current_width = current_width - (2 * padding) + + unpadded_features = current_height * current_width + newline_features = current_height + + # Get max_num_patches from vision_aspect_ratio config + hf_config = self.get_hf_config() + vision_aspect_ratio = getattr(hf_config, "vision_aspect_ratio", "anyres_max_9") + max_num_patches = int(vision_aspect_ratio.replace("anyres_max_", "")) + + ratio = math.sqrt( + current_height * current_width / (max_num_patches * npatches**2) + ) + if ratio > 1.1: + height_factor = int(current_height // ratio) + width_factor = int(current_width // ratio) + unpadded_features = height_factor * width_factor + newline_features = height_factor + + return (unpadded_features, newline_features) + + +class BeeDummyInputsBuilder(LlavaDummyInputsBuilder[BeeProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + image_token = "<image>" + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = 
mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + } + + +class BeeMultiModalProjector(nn.Module): + def __init__(self, config): + super().__init__() + self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-06) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size * 4, + bias=True, + ) + self.act = GELUActivation() + self.linear_2 = nn.Linear( + config.text_config.hidden_size * 4, + config.text_config.hidden_size, + bias=True, + ) + + def forward(self, image_feature: torch.Tensor) -> torch.Tensor: + image_feature = self.pre_norm(image_feature) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextMultiModalProcessor, + info=BeeProcessingInfo, + dummy_inputs=BeeDummyInputsBuilder, +) +class BeeForConditionalGeneration(LlavaOnevisionForConditionalGeneration): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers + # v4.55 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "model.image_newline": "image_newline", + "lm_head.": "language_model.lm_head.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + self.multi_modal_projector = BeeMultiModalProjector(config) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index d9d4c62639d5..1c2334a78543 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Set -from typing import Optional, Union import torch from torch import nn @@ -66,7 +65,7 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: token_type_ids = _decode_token_type_ids(input_ids) @@ -103,9 +102,9 @@ def _head(self, pooled_output: torch.Tensor): def forward( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: pooled_output = self.pooling(hidden_states, pooling_metadata) if isinstance(pooled_output, list): @@ -147,8 +146,8 @@ class BertLayer(nn.Module): def __init__( self, config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -191,8 +190,8 @@ def __init__( hidden_size: int, num_attention_heads: int, layer_norm_eps: float, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -225,8 +224,8 @@ def __init__( self, hidden_size: int, 
num_attention_heads: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -281,7 +280,7 @@ def __init__( self, hidden_size: int, layer_norm_eps: float, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -308,7 +307,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -333,7 +332,7 @@ def __init__( hidden_size: int, intermediate_size: int, layer_norm_eps: float, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -383,8 +382,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: hidden_states = self.embeddings( input_ids=input_ids, @@ -494,8 +493,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: return self.model( input_ids=input_ids, @@ -522,7 +521,7 @@ def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel: def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: return DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), } ) @@ -573,6 +572,220 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: return token_type_ids +class BertMLMHead(nn.Module): + def __init__( + self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12 + ): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.GELU() + self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.decoder = nn.Linear(hidden_size, vocab_size, bias=True) + + def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor): + self.decoder.weight = embeddings_weight + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x = self.dense(hidden_states) + x = self.activation(x) + x = self.layer_norm(x) + logits = self.decoder(x) + return logits + + +class SPLADESparsePooler(Pooler): + """ + SPLADE sparse pooling: + logits = mlm_head(hidden_states) + -> log1p(relu(logits)) + -> (max|sum over L) + -> [V] + + Padding is masked with an attention mask, + [CLS]/[SEP] is removed (selected), + and then pooled. 
+ """ + + def __init__( + self, + mlm_head: nn.Module, + cls_token_id: int | None = 101, + sep_token_id: int | None = 102, + pooling: str = "max", + remove_cls_sep: bool = True, + ): + super().__init__() + assert pooling in ("max", "sum") + self.mlm_head = mlm_head + self.cls_token_id = cls_token_id + self.sep_token_id = sep_token_id + self.pooling = pooling + self.remove_cls_sep = remove_cls_sep + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"embed"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> torch.Tensor: + assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2 + + lens_tensor: torch.Tensor = pooling_metadata.prompt_lens + lens: list[int] = lens_tensor.tolist() + B: int = len(lens) + + token_ids = pooling_metadata.prompt_token_ids + offset = 0 + pooled_list: list[torch.Tensor] = [] + + for i in range(B): + L = int(lens[i]) + hs = hidden_states[offset : offset + L] + + start_idx = 0 + end_idx = L + if self.remove_cls_sep and token_ids is not None: + if ( + self.cls_token_id is not None + and token_ids[i, 0].item() == self.cls_token_id + ): + start_idx = 1 + if ( + self.sep_token_id is not None + and token_ids[i, L - 1].item() == self.sep_token_id + ): + end_idx = max(start_idx, L - 1) + + if end_idx <= start_idx: + V = int(self.mlm_head.decoder.out_features) + pooled_list.append(hs.new_zeros((V,))) + offset += L + continue + + logits_i = self.mlm_head(hs[start_idx:end_idx]) + scores_i = torch.log1p(torch.relu(logits_i)) + + if self.pooling == "sum": + pooled_i = scores_i.sum(dim=0) + else: # "max" + pooled_i = scores_i.max(dim=0).values + + pooled_list.append(pooled_i.contiguous()) + offset += L + + return torch.stack(pooled_list, dim=0).contiguous() + + +@default_pooling_type("CLS") +class BertSpladeSparseEmbeddingModel(BertEmbeddingModel): + """ + BertEmbeddingModel + SPLADE sparse embedding. + - Make logits by self.mlm_head + - pooler: SPLADESparsePooler(mlm_head...) 
+ """ + + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max" + ): + super().__init__(vllm_config=vllm_config, prefix=prefix) + cfg = vllm_config.model_config.hf_config + + # MLM head + self.mlm_head = BertMLMHead( + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), + ) + + self._splade_pooling = splade_pooling + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.pooler = self._build_pooler(pooler_config) + + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + cfg = self.model.config + + if not hasattr(self, "mlm_head"): + self.mlm_head = BertMLMHead( + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), + ) + + pooling_mode = getattr(self, "_splade_pooling", "max") + + cls_id = getattr(cfg, "cls_token_id", None) + sep_id = getattr(cfg, "sep_token_id", None) + + return DispatchPooler( + { + "token_embed": Pooler.for_token_embed(pooler_config), + "embed": SPLADESparsePooler( + mlm_head=self.mlm_head, + cls_token_id=cls_id, + sep_token_id=sep_id, + pooling=pooling_mode, # "max" or "sum" + remove_cls_sep=True, + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + if not hasattr(self, "mlm_head"): + cfg = self.model.config + self.mlm_head = BertMLMHead( + hidden_size=cfg.hidden_size, + vocab_size=cfg.vocab_size, + layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12), + ) + + def _strip(name: str) -> str: + for p in ("model.", "bert."): + if name.startswith(p): + name = name[len(p) :] + return name + + weights_list = list(weights) + model_side: list[tuple[str, torch.Tensor]] = [] + mlm_side: list[tuple[str, torch.Tensor]] = [] + + for k, w in weights_list: + name = _strip(k) + if name.startswith("cls.predictions."): + mlm_side.append((name, w)) + else: + model_side.append((name, w)) + + loaded: set[str] = set() + loaded_model = self.model.load_weights(model_side) + loaded.update({"model." + n for n in loaded_model}) + + if mlm_side: + name_map = { + "cls.predictions.transform.dense.weight": "mlm_head.dense.weight", + "cls.predictions.transform.dense.bias": "mlm_head.dense.bias", + ("cls.predictions.transform.LayerNorm.weight"): ( + "mlm_head.layer_norm.weight" + ), + ("cls.predictions.transform.LayerNorm.bias"): ( + "mlm_head.layer_norm.bias" + ), + "cls.predictions.decoder.weight": "mlm_head.decoder.weight", + "cls.predictions.decoder.bias": "mlm_head.decoder.bias", + } + remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map] + if remapped: + loaded_mlm = AutoWeightsLoader(self).load_weights(remapped) + loaded.update(loaded_mlm) + + return loaded + + @default_pooling_type("CLS") class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant): """A model that uses Bert to provide embedding functionalities. 
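As a quick, self-contained illustration of the SPLADE pooling introduced in this hunk: the sketch below applies the same log1p(relu(.)) transform and max/sum reduction to toy tensors. The single nn.Linear stands in for the full BertMLMHead, and every size here is made up; this is a sketch of the math only, not of the vLLM pooler API.

import torch

hidden_size, vocab_size, seq_len = 8, 32, 6
mlm_head = torch.nn.Linear(hidden_size, vocab_size)  # stand-in for BertMLMHead
hidden_states = torch.randn(seq_len, hidden_size)    # one sequence, [L, H]

logits = mlm_head(hidden_states)                     # [L, V]
# Drop the first/last positions ([CLS]/[SEP]) as remove_cls_sep=True does,
# then apply the SPLADE activation to get non-negative, sparse-ish scores.
scores = torch.log1p(torch.relu(logits[1:-1]))       # [L-2, V]

pooled_max = scores.max(dim=0).values                # pooling="max" -> [V]
pooled_sum = scores.sum(dim=0)                       # pooling="sum" -> [V]
assert pooled_max.shape == (vocab_size,)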
@@ -608,20 +821,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( pooling=self.bert.pooler, classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + act_fn="classify", ), "score": ClassifierPooler( - pooling=self.bert.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=self.bert.pooler, classifier=self.classifier, act_fn="score" ), } ) @@ -636,11 +845,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, ) -> torch.Tensor: if token_type_ids is not None: assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) @@ -678,7 +887,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config=pooler_config + ), } ) @@ -692,11 +903,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, ) -> torch.Tensor: if token_type_ids is not None: assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 05cb0e22a0aa..31fdc4d21245 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch from torch import nn @@ -67,7 +66,7 @@ def __init__(self, config: PretrainedConfig): def forward( self, input_ids: torch.Tensor, - token_type_ids: Optional[torch.Tensor] = None, + token_type_ids: torch.Tensor | None = None, ) -> torch.Tensor: input_shape = input_ids.size() inputs_embeds = self.word_embeddings(input_ids) @@ -91,10 +90,10 @@ def __init__( self, hidden_size: int, num_attention_heads: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, bias: bool = True, - rotary_kwargs: Optional[dict] = None, + rotary_kwargs: dict | None = None, prefix: str = "", ): super().__init__() @@ -166,7 +165,7 @@ def __init__( intermediate_size: int, hidden_act: str, bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, + quant_config: 
QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -200,7 +199,7 @@ def __init__( intermediate_size: int, hidden_act: str, bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -235,8 +234,8 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, + params_dtype: torch.dtype | None = None, + tp_size: int | None = None, ): super().__init__() @@ -344,11 +343,11 @@ class BertWithRopeBlock(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, moe: bool = False, bias: bool = True, - rotary_kwargs: Optional[dict] = None, + rotary_kwargs: dict | None = None, prefix: str = "", ): super().__init__() @@ -406,7 +405,7 @@ def __init__( self, vllm_config: VllmConfig, bias: bool = True, - rotary_kwargs: Optional[dict] = None, + rotary_kwargs: dict | None = None, prefix: str = "", ): super().__init__() @@ -471,9 +470,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -696,20 +695,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( pooling=self.new.pooler, classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + act_fn="classify", ), "score": ClassifierPooler( - pooling=self.new.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=self.new.pooler, classifier=self.classifier, act_fn="score" ), } ) @@ -724,10 +719,10 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: return self.new( input_ids=input_ids, diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index aa361e0a2a39..2e4f73312efa 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -4,7 +4,6 @@ within a vision language model.""" from collections.abc import Iterable -from typing import Optional, Union import torch import torch.nn as nn @@ -38,7 +37,7 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa class BlipVisionEmbeddings(nn.Module): - def __init__(self, config: 
Union[BlipVisionConfig, Blip2VisionConfig]): + def __init__(self, config: BlipVisionConfig | Blip2VisionConfig): super().__init__() self.config = config @@ -86,8 +85,8 @@ class BlipAttention(nn.Module): def __init__( self, - config: Union[BlipVisionConfig, Blip2VisionConfig], - quant_config: Optional[QuantizationConfig] = None, + config: BlipVisionConfig | Blip2VisionConfig, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -151,7 +150,7 @@ class BlipMLP(nn.Module): def __init__( self, config: BlipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -186,7 +185,7 @@ class BlipEncoderLayer(nn.Module): def __init__( self, config: BlipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -228,8 +227,8 @@ class BlipEncoder(nn.Module): def __init__( self, config: BlipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - num_hidden_layers_override: Optional[int] = None, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, prefix: str = "", ) -> None: super().__init__() @@ -268,10 +267,10 @@ class BlipVisionModel(nn.Module, SupportsQuant): def __init__( self, config: BlipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", ) -> None: super().__init__() diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8e94d5935026..2986a72f2e48 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal, TypeAlias import torch import torch.nn as nn @@ -70,7 +70,7 @@ class Blip2ImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] -Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs] +Blip2ImageInputs: TypeAlias = Blip2ImagePixelInputs | Blip2ImageEmbeddingInputs class Blip2QFormerMultiHeadAttention(nn.Module): @@ -78,8 +78,8 @@ def __init__( self, config: Blip2QFormerConfig, *, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, is_cross_attention: bool = False, prefix: str = "", ) -> None: @@ -123,7 +123,7 @@ def transpose_for_scores(self, x): def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: torch.FloatTensor | None = None, ): is_cross_attention = encoder_hidden_states is not None @@ -179,8 +179,8 @@ def __init__( self, config: Blip2QFormerConfig, *, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, is_cross_attention: bool = False, prefix: str = "", ) -> None: @@ -199,7 +199,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, 
- encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: torch.FloatTensor | None = None, ) -> tuple[torch.Tensor]: self_output = self.attention( hidden_states, @@ -247,8 +247,8 @@ def __init__( self, config: Blip2QFormerConfig, *, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, layer_idx: int, prefix: str = "", ) -> None: @@ -340,8 +340,8 @@ def __init__( self, config: Blip2QFormerConfig, *, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, prefix: str = "", ) -> None: super().__init__() @@ -385,8 +385,8 @@ def __init__( self, config: Blip2QFormerConfig, *, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, prefix: str = "", ) -> None: super().__init__() @@ -426,7 +426,7 @@ class Blip2ProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(Blip2Config) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": 1} def get_num_image_tokens(self) -> int: @@ -442,7 +442,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config @@ -526,7 +526,7 @@ class Blip2ForConditionalGeneration( merge_by_field_config = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None @@ -573,7 +573,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Blip2ImageInputs]: + ) -> Blip2ImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -641,8 +641,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> IntermediateTensors: """Run forward pass for BLIP-2. 
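For reference, the annotation migration applied throughout these model files follows the PEP 604 / PEP 613 style sketched below; the names are illustrative, and evaluating `|` between types at runtime (as in the Blip2ImageInputs alias above) assumes Python 3.10+.

from typing import TypeAlias

import torch

# Before: Optional[torch.Tensor] and Union[A, B]
# After:  torch.Tensor | None    and A | B
MaybeTensor: TypeAlias = torch.Tensor | None


def scale(x: torch.Tensor, factor: float | None = None) -> torch.Tensor:
    # None means "leave the tensor unchanged".
    return x if factor is None else x * factor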
@@ -687,7 +687,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 4a814fc4020d..bbbd14adf92b 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -22,7 +22,6 @@ import math from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -89,8 +88,8 @@ class BloomAttention(nn.Module): def __init__( self, config: BloomConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -152,7 +151,7 @@ class BloomMLP(nn.Module): def __init__( self, config: BloomConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() hidden_size = config.hidden_size @@ -179,8 +178,8 @@ class BloomBlock(nn.Module): def __init__( self, config: BloomConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -274,9 +273,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -356,9 +355,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -367,7 +366,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index d8756e236f4c..6f7e18d78bad 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from itertools import islice -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal import torch import torch.nn as nn @@ -94,7 +94,7 @@ def get_hf_config(self): def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": 1} 
def get_num_image_tokens(self) -> int: @@ -115,7 +115,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: config = self.info.get_hf_config() @@ -225,7 +225,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, ) -> None: super().__init__() @@ -262,11 +262,11 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 4096, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -357,8 +357,8 @@ class ChameleonDecoderLayer(nn.Module): def __init__( self, config: ChameleonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -403,8 +403,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -426,8 +426,8 @@ class ChameleonSwinDecoderLayer(nn.Module): def __init__( self, config: ChameleonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -472,7 +472,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: residual = hidden_states hidden_states = self.self_attn( @@ -896,11 +896,11 @@ def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -941,7 +941,7 @@ class ChameleonForConditionalGeneration( } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -975,7 +975,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[ChameleonImagePixelInputs]: + ) -> ChameleonImagePixelInputs | None: pixel_values = kwargs.pop("pixel_values", None) if pixel_values is None: @@ -999,7 +999,7 @@ def 
get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: return [] assert self.model.vqmodel is not None image_tokens = self.model.get_image_tokens( - image_input["data"].to(self.config.torch_dtype) + image_input["data"].to(self.config.dtype) ) vision_embeddings = self.model.get_input_embeddings(image_tokens) return vision_embeddings @@ -1008,10 +1008,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -1023,7 +1023,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) # Disallow image tokens which does not include special diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ece719df61f7..bcbe82b78c3b 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -7,7 +7,6 @@ import json from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -50,8 +49,8 @@ class GLMAttention(nn.Module): def __init__( self, config: ChatGLMConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -144,7 +143,7 @@ class GLMMLP(nn.Module): def __init__( self, config: ChatGLMConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -190,8 +189,8 @@ class GLMBlock(nn.Module): def __init__( self, config: ChatGLMConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -263,8 +262,8 @@ class GLMTransformer(nn.Module): def __init__( self, config: ChatGLMConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -295,7 +294,7 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer( hidden_states=hidden_states, position_ids=position_ids @@ -361,10 +360,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -459,7 +458,7 @@ def 
get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -491,9 +490,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index f05d5c4cc1d8..27953c27188d 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch import torch.nn as nn @@ -125,7 +125,7 @@ def get_vision_encoder_info(self): def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(CLIPProcessor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": 1} def get_num_image_tokens( @@ -169,7 +169,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -199,12 +199,12 @@ def image_token_id(self) -> int: def apply( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, + tokenization_kwargs: Mapping[str, object] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: if prompt and mm_data: raise ValueError( @@ -286,9 +286,9 @@ def __init__(self, config: CLIPTextConfig): def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, position_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if inputs_embeds is None: if input_ids is None: @@ -350,11 +350,11 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class CLIPAttention(nn.Module): def __init__( self, - config: Union[CLIPTextConfig, CLIPVisionConfig], - quant_config: Optional[QuantizationConfig] = None, + config: CLIPTextConfig | CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, *, prefix: str = "", - attn_cls: Union[type[Attention], type[MultiHeadAttention]], + attn_cls: type[Attention] | type[MultiHeadAttention], ) -> None: super().__init__() @@ -412,8 +412,8 @@ def forward( class CLIPMLP(nn.Module): def __init__( self, - config: Union[CLIPTextConfig, CLIPVisionConfig], - quant_config: Optional[QuantizationConfig] = None, + config: CLIPTextConfig | CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -445,11 
+445,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class CLIPEncoderLayer(nn.Module): def __init__( self, - config: Union[CLIPTextConfig, CLIPVisionConfig], - quant_config: Optional[QuantizationConfig] = None, + config: CLIPTextConfig | CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, *, prefix: str = "", - attn_cls: Union[type[Attention], type[MultiHeadAttention]], + attn_cls: type[Attention] | type[MultiHeadAttention], ) -> None: super().__init__() self.self_attn = CLIPAttention( @@ -488,12 +488,12 @@ class CLIPEncoder(nn.Module): def __init__( self, - config: Union[CLIPTextConfig, CLIPVisionConfig], - quant_config: Optional[QuantizationConfig] = None, - num_hidden_layers_override: Optional[int] = None, + config: CLIPTextConfig | CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, *, prefix: str = "", - attn_cls: Union[type[Attention], type[MultiHeadAttention]], + attn_cls: type[Attention] | type[MultiHeadAttention], ) -> None: super().__init__() @@ -519,7 +519,7 @@ def forward( self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: hidden_states_pool = [inputs_embeds] hidden_states = inputs_embeds @@ -538,7 +538,7 @@ class CLIPTextTransformer(nn.Module): def __init__( self, config: CLIPTextConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: @@ -566,9 +566,9 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, position_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: hidden_states = self.embeddings( input_ids=input_ids, @@ -616,10 +616,10 @@ class CLIPVisionTransformer(nn.Module): def __init__( self, config: CLIPVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", ) -> None: super().__init__() @@ -669,8 +669,8 @@ def forward( self, pixel_values: torch.Tensor, *, - select_layers: Optional[list[int]] = None, - feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + select_layers: list[int] | None = None, + feature_select_strategy: VisionFeatureSelectStrategy | None = None, ) -> torch.Tensor: hidden_states = self.embeddings(pixel_values) hidden_states = self.pre_layrnorm(hidden_states) @@ -736,10 +736,10 @@ class CLIPVisionModel(nn.Module): def __init__( self, config: CLIPVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", ) -> None: super().__init__() @@ -755,8 +755,8 @@ def __init__( def forward( self, pixel_values: torch.Tensor, - select_layers: Optional[list[int]] = None, - feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + select_layers: list[int] | None = None, + feature_select_strategy: VisionFeatureSelectStrategy | None = 
None, ) -> torch.Tensor: return self.vision_model( pixel_values, @@ -787,7 +787,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): merge_by_field_config = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None @@ -837,7 +837,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": Pooler.for_embed(pooler_config), } ) @@ -847,9 +847,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_text_features( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, position_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: pooled_output = self.text_model( input_ids=input_ids, @@ -864,7 +864,7 @@ def get_text_features( def get_image_features( self, pixel_values: torch.Tensor, - feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + feature_select_strategy: VisionFeatureSelectStrategy | None = None, ) -> torch.Tensor: if feature_select_strategy is None: feature_select_strategy = _get_vision_feature_select_strategy( @@ -883,7 +883,7 @@ def get_image_features( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[CLIPImagePixelInputs]: + ) -> CLIPImagePixelInputs | None: pixel_values = kwargs.pop("pixel_values", None) if pixel_values is None: return None @@ -906,9 +906,9 @@ def get_language_model(self) -> torch.nn.Module: def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: self._is_text_input = ( @@ -936,10 +936,10 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 73aafbd01144..19cc31c9bd18 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -4,7 +4,7 @@ """Command-A-Vision (Cohere2Vision) multimodal model implementation for vLLM.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch from torch import nn @@ -148,7 +148,7 @@ def get_hf_processor(self, **kwargs: object) -> Cohere2VisionProcessor: def get_image_processor(self, **kwargs: object): return self.get_hf_processor(**kwargs).image_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_image_size_with_most_features(self) -> ImageSize: @@ -163,7 +163,7 @@ def 
get_num_patches( *, image_width: int, image_height: int, - processor: Optional[Cohere2VisionProcessor], + processor: Cohere2VisionProcessor | None, ) -> int: """ Calculate the number of image patches for a given image. @@ -217,7 +217,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) image_size = self.info.get_image_size_with_most_features() @@ -404,7 +404,7 @@ def _process_image_input( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Cohere2VisionImagePixelInputs]: + ) -> Cohere2VisionImagePixelInputs | None: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -450,10 +450,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -468,5 +468,5 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index e38c3c0492fb..75459601f76b 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -25,7 +25,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -94,8 +93,8 @@ def forward(self, hidden_states, residuals=None): class CohereMLP(nn.Module): def __init__( self, - config: Union[CohereConfig, Cohere2Config], - quant_config: Optional[QuantizationConfig] = None, + config: CohereConfig | Cohere2Config, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -128,9 +127,9 @@ def forward(self, x): class CohereAttention(nn.Module): def __init__( self, - config: Union[CohereConfig, Cohere2Config], - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + config: CohereConfig | Cohere2Config, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -241,9 +240,9 @@ def forward( class CohereDecoderLayer(nn.Module): def __init__( self, - config: Union[CohereConfig, Cohere2Config], - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + config: CohereConfig | Cohere2Config, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -265,7 +264,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -324,9 +323,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = 
None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -452,9 +451,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -463,7 +462,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: is_not_lora = hasattr(self.model.embed_tokens, "weight") if is_not_lora: logits = self.logits_processor(self.model.embed_tokens, hidden_states) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index caf481f5aec6..d4367be1c785 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -6,7 +6,8 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import cdiv, round_up +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: @@ -59,16 +60,26 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class JinaRobertaModelConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: - config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + config = model_config.hf_config if config.position_embedding_type == "rotary": assert config.__class__.__name__ == "XLMRobertaFlashConfig" head_dim = config.hidden_size // config.num_attention_heads + max_position = config.max_position_embeddings + # Jina-embeddings-v3 has max_position_embeddings=8194, which will cause + # out-of-bound index issue at RoPE for long prompts with torch.compile, + # because it can't be divided by triton num_warps(default=4 or 8). + # To deal with this, we increase max_position to multiple of n_warps, + # so that triton kernel won't hit out-of-bound index in RoPE cache. + if not model_config.enforce_eager: + max_position = round_up(max_position, 8) + config.rotary_kwargs = { "head_size": head_dim, "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), - "max_position": config.max_position_embeddings, + "max_position": max_position, "base": getattr(config, "rope_theta", config.rotary_emb_base), "rope_scaling": getattr(config, "rope_scaling", None), } @@ -248,21 +259,19 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: # Increase the max capture size from 512 to 992 for performance. # NOTE(woosuk): This will increase the number of CUDA graphs # from 67 to 81. - scheduler_config = vllm_config.scheduler_config - if len(scheduler_config.cuda_graph_sizes) == 1: - max_capture_size = scheduler_config.cuda_graph_sizes[0] + compilation_config = vllm_config.compilation_config + # Only override when the user has not set either of + # cudagraph_capture_sizes or max_cudagraph_capture_size. 
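The `JinaRobertaModelConfig` hunk above rounds `max_position_embeddings` up to a multiple of 8 before it is baked into `rotary_kwargs`. A minimal sketch of that arithmetic; `round_up` is re-implemented locally and assumed to match the ceiling-to-multiple helper the hunk imports from `vllm.utils`:

```python
# Minimal sketch of the max_position rounding in JinaRobertaModelConfig above.
# round_up is assumed to behave like vllm.utils.round_up (ceiling to a multiple);
# the 8194 value is the jina-embeddings-v3 size mentioned in the comment above.
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

max_position_embeddings = 8194
max_position = round_up(max_position_embeddings, 8)
assert max_position == 8200   # now divisible by triton num_warps (4 or 8)
```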
+ if ( + compilation_config.cudagraph_capture_sizes is None + and compilation_config.max_cudagraph_capture_size is None + ): # FIXME(woosuk): When using full cuda graph with FA3, the max # supported size is 992. - if max_capture_size < 992: - cuda_graph_sizes = [1, 2, 4] - # Step size 8 for small batch sizes - cuda_graph_sizes += [i for i in range(8, 256, 8)] - # Step size 16 for larger batch sizes - cuda_graph_sizes += [i for i in range(256, 993, 16)] - scheduler_config.cuda_graph_sizes = cuda_graph_sizes - logger.info( - "Overriding max cuda graph capture size to %d for performance.", 992 - ) + compilation_config.max_cudagraph_capture_size = 992 + logger.info( + "Overriding max cuda graph capture size to %d for performance.", 992 + ) class MambaModelConfig(VerifyAndUpdateConfig): @@ -365,6 +374,23 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=model_config.max_model_len, ).page_size_bytes + # Model may be marked as is_hybrid + # but mamba is skipped via config, + # return directly + if mamba_page_size == 0: + return + + # Attention backend constraints: + # - FlashAttention (FA) requires block size to be multiple of 16 + # - MLA (Multi-head Latent Attention) requires larger alignment: + # * CUTLASS_MLA backend: 128-byte alignment + # * Other MLA backends: 64-byte alignment + if model_config.use_mla: + use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + kernel_block_alignment_size = 128 if use_cutlass_mla else 64 + else: + kernel_block_alignment_size = 16 + if cache_config.enable_prefix_caching: # With prefix caching, select attention block size to # optimize for mamba kernel performance @@ -381,19 +407,28 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(tdoublep): this constraint can be relaxed fairly # easily by changing the way we layout chunks in the # mamba2 kernels. - chunk_size = model_config.get_mamba_chunk_size() + + from math import gcd + + def lcm(a, b): + return a * b // gcd(a, b) + + base_chunk_size = model_config.get_mamba_chunk_size() attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) + + chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) cache_config.mamba_block_size = attn_block_size else: # Without prefix caching, select minimum valid attention block size # to minimize mamba state padding - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token) + # Calculate minimum attention block size that satisfies both: + # 1. Backend alignment requirements (kernel_block_alignment_size) + # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size) + attn_block_size = kernel_block_alignment_size * cdiv( + mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token + ) # override attention block size if either (a) the # user has not set it or (b) the user has set it @@ -444,12 +479,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: is_v32 = hasattr(hf_config, "index_topk") assert is_v32 - # For DeepSeekV3.2, we use a custom fp8 format as default (i.e. - # "auto") + # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled. 
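The `HybridAttentionMambaModelConfig` hunk above derives the attention block size from the mamba page size while respecting the new per-backend alignment (16 for FlashAttention, 64 or 128 for MLA). A worked sketch of both branches; the page sizes below are illustrative, only the formulas come from the hunk:

```python
from math import gcd

def cdiv(a: int, b: int) -> int:       # ceiling division, as in vllm.utils
    return -(-a // b)

def lcm(a: int, b: int) -> int:        # same helper the hunk defines inline
    return a * b // gcd(a, b)

# Illustrative values (not taken from any real checkpoint):
mamba_page_size = 180_000              # bytes for one mamba state
attn_page_size_1_token = 576           # bytes per attention token in a block
kernel_block_alignment_size = 64       # e.g. a non-CUTLASS MLA backend
base_chunk_size = 256                  # mamba2 chunk size

# Without prefix caching: smallest aligned block whose page covers the mamba state.
attn_block_size = kernel_block_alignment_size * cdiv(
    mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
print(attn_block_size)                 # 320: multiple of 64, and 320 * 576 >= 180_000

# With prefix caching: round up to whole chunks, with the chunk itself aligned.
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)            # 256
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)  # 313
print(chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size))         # 512
```

In both branches the resulting block size is a multiple of the kernel alignment and its attention page is at least as large as the mamba state, which is what the override relies on.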
cache_config = vllm_config.cache_config - if cache_config.cache_dtype == "auto" or cache_config.cache_dtype.startswith( - "fp8" - ): + if cache_config.cache_dtype.startswith("fp8"): cache_config.cache_dtype = "fp8_ds_mla" logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") if cache_config.cache_dtype == "bfloat16": diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 8ec7a82a7b2a..088960e06448 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -3,7 +3,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch import torch.nn as nn @@ -54,7 +53,7 @@ class DbrxRouter(nn.Module): def __init__( self, config: DbrxConfig, - params_dtype: Optional[torch.dtype] = None, + params_dtype: torch.dtype | None = None, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -77,8 +76,8 @@ class DbrxExperts(FusedMoE): def __init__( self, config: DbrxConfig, - quant_config: Optional[QuantizationConfig] = None, - params_dtype: Optional[torch.dtype] = None, + quant_config: QuantizationConfig | None = None, + params_dtype: torch.dtype | None = None, prefix: str = "", ): super().__init__( @@ -157,8 +156,8 @@ class DbrxMoE(nn.Module): def __init__( self, config: DbrxConfig, - quant_config: Optional[QuantizationConfig] = None, - params_dtype: Optional[torch.dtype] = None, + quant_config: QuantizationConfig | None = None, + params_dtype: torch.dtype | None = None, prefix: str = "", ): super().__init__() @@ -189,8 +188,8 @@ class DbrxAttention(nn.Module): def __init__( self, config: DbrxConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -270,8 +269,8 @@ class DbrxFusedNormAttention(nn.Module): def __init__( self, config: DbrxConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -303,8 +302,8 @@ class DbrxBlock(nn.Module): def __init__( self, config: DbrxConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -361,9 +360,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -466,9 +465,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -477,7 +476,7 @@ def forward( 
def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py new file mode 100644 index 000000000000..e62a57eccc95 --- /dev/null +++ b/vllm/model_executor/models/deepencoder.py @@ -0,0 +1,673 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from +# https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/deepencoder/sam_vary_sdpa.py + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import math +from collections.abc import Iterable +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import CLIPVisionConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .clip import CLIPEncoder, CLIPVisionEmbeddings + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: type[nn.Module] = nn.LayerNorm, + act_layer: type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ # noqa: E501 + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: nn.Parameter | None = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) + self.net_3 = nn.Conv2d( + 512, 1024, kernel_size=3, stride=2, padding=1, bias=False + ) + + def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int): + dtype = abs_pos.dtype + + src_size = abs_pos.size(1) + + if src_size != tgt_size: + old_pos_embed = abs_pos.permute(0, 3, 1, 2) + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, tgt_size), + mode="bicubic", + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + return new_pos_embed + else: + return abs_pos + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.get_abs_pos(self.pos_embed, x.size(1)) + + for blk in self.blocks: + x = blk(x) + + neck_output = self.neck(x.permute(0, 3, 1, 2)) + conv2_output = self.net_2(neck_output) + conv3_output = self.net_3(conv2_output) + + return conv3_output + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation + blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: type[nn.Module] = nn.LayerNorm, + act_layer: type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: tuple[int, int] | None = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. 
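For orientation, a shape-only sketch of the `ImageEncoderViT` tail above (`neck`, `net_2`, `net_3`) for the 1024x1024 input that `_build_sam` assumes; `LayerNorm2d` is swapped for `Identity` because only the spatial and channel shapes are being illustrated:

```python
import torch
import torch.nn as nn

# 1024x1024 input with patch 16 -> 64x64 tokens of width 768 after the ViT blocks.
tokens = torch.randn(1, 64, 64, 768)                 # (B, H, W, C)
neck = nn.Sequential(
    nn.Conv2d(768, 256, kernel_size=1, bias=False),
    nn.Identity(),                                   # stands in for LayerNorm2d
    nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
    nn.Identity(),
)
net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)

x = neck(tokens.permute(0, 3, 1, 2))                 # (1, 256, 64, 64)
x = net_2(x)                                         # (1, 512, 32, 32)
x = net_3(x)                                         # (1, 1024, 16, 16)
print(x.shape)                                       # torch.Size([1, 1024, 16, 16])
```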
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ # noqa: E501 + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = RelPosAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class RelPosAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: tuple[int, int] | None = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ # noqa: E501 + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert input_size is not None, ( + "Input size must be provided if using relative positional encoding." 
+ ) + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + rel_h, rel_w = None, None + if self.use_rel_pos: + rel_h, rel_w = add_decomposed_rel_pos( + q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + q = q.view(B, self.num_heads, H * W, -1) + k = k.view(B, self.num_heads, H * W, -1) + v = v.view(B, self.num_heads, H * W, -1) + + if self.use_rel_pos: + rel_h = rel_h.view( + B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3) + ) + rel_w = rel_w.view( + B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3) + ) + attn_bias = (rel_h + rel_w).view( + B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4) + ) + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_bias + ) + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + + x = ( + x.view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> tuple[torch.Tensor, tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ # noqa: E501 + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: tuple[int, int], + hw: tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ # noqa: E501 + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). 
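A round-trip check for the `window_partition` / `window_unpartition` helpers above, using the SAM-B window size of 14 on a 64x64 token grid (assumes this diff is applied so the new `deepencoder` module is importable):

```python
import torch

from vllm.model_executor.models.deepencoder import (
    window_partition,
    window_unpartition,
)

# 64 is not a multiple of 14, so the grid is padded to 70x70 -> 5x5 = 25 windows.
x = torch.randn(1, 64, 64, 768)                       # (B, H, W, C)
windows, (Hp, Wp) = window_partition(x, window_size=14)
print(windows.shape, (Hp, Wp))                        # (25, 14, 14, 768) (70, 70)

y = window_unpartition(windows, 14, (Hp, Wp), (64, 64))
assert y.shape == x.shape and torch.equal(x, y)       # padding is stripped exactly
```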
+ + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + dtype = rel_pos.dtype + rel_pos = rel_pos.to(torch.float32) + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ).to(dtype) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max( + k_size / q_size, 1.0 + ) + k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max( + q_size / k_size, 1.0 + ) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: tuple[int, int], + k_size: tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + Args: + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ # noqa: E501 + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + rel_h = rel_h.unsqueeze(-1) + rel_w = rel_w.unsqueeze(-2) + rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1) + rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w) + + return rel_h, rel_w + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: tuple[int, int] = (16, 16), + stride: tuple[int, int] = (16, 16), + padding: tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. 
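A shape sketch for `get_rel_pos` / `add_decomposed_rel_pos` above: with a 14x14 window the two decomposed terms broadcast into the additive attention bias that `RelPosAttention` passes to SDPA. The head dimension is folded into the batch here, as in that forward pass (assumes this diff is applied):

```python
import torch

from vllm.model_executor.models.deepencoder import add_decomposed_rel_pos

B, q_h, q_w, head_dim = 2, 14, 14, 64                 # B stands for batch * num_heads
q = torch.randn(B, q_h * q_w, head_dim)
rel_pos_h = torch.randn(2 * q_h - 1, head_dim)        # (27, 64)
rel_pos_w = torch.randn(2 * q_w - 1, head_dim)

rel_h, rel_w = add_decomposed_rel_pos(q, rel_pos_h, rel_pos_w, (q_h, q_w), (q_h, q_w))
print(rel_h.shape, rel_w.shape)                       # (2, 196, 14, 1) (2, 196, 1, 14)

# Broadcasting the sum and flattening k_h * k_w gives the additive bias.
bias = (rel_h + rel_w).reshape(B, q_h * q_w, q_h * q_w)
print(bias.shape)                                     # (2, 196, 196)
```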
+ """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +# TODO(Isotr0py): use vision_config to build sam model +def build_sam_vit_b(): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + ) + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + return image_encoder + + +class DeepCLIPVisionEmbeddings(CLIPVisionEmbeddings): + def get_abs_pos(self, abs_pos: torch.Tensor, tgt_size: int): + # abs_pos: L, C + # tgt_size: M + # return: M, C + dim = abs_pos.size(-1) + abs_pos_new = abs_pos.squeeze(0) + cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:] + + src_size = int(math.sqrt(abs_pos_new.shape[0] - 1)) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + old_pos_embed = ( + old_pos_embed.view(1, src_size, src_size, dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + old_pos_embed = old_pos_embed.to(torch.float32) + new_pos_embed = F.interpolate( + old_pos_embed, + size=(tgt_size, tgt_size), + mode="bicubic", + antialias=True, + align_corners=False, + ).to(dtype) + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) + new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim) + vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0) + vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim) + return vision_pos_embed + else: + return abs_pos + + def forward( + self, pixel_values: torch.Tensor, patch_embeds: torch.Tensor | None = None + ) -> torch.Tensor: + batch_size = pixel_values.shape[0] + if patch_embeds is not None: + patch_embeds = patch_embeds + else: + patch_embeds = self.patch_embedding(pixel_values) + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.get_abs_pos( + self.position_embedding(self.position_ids), embeddings.size(1) + ) + return embeddings + + +class DeepCLIPVisionTransformer(nn.Module): + def __init__( + self, + config: CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, + *, + num_hidden_layers_override: int | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = DeepCLIPVisionEmbeddings(config) + + # NOTE: This typo of "layrnorm" is not fixed on purpose to match + # the original transformers code and name of the model weights. 
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.transformer = CLIPEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + attn_cls=MultiHeadAttention, + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.transformer.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.transformer.layers)} layers." + ) + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + pixel_values: torch.Tensor, + patch_embeds: torch.Tensor | None = None, + *, + select_layers: list[int] | None = None, + ) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values, patch_embeds) + hidden_states = self.pre_layrnorm(hidden_states) + + # Produces either the last layer output or all of the hidden states, + # depending on if we have select_layers or not + encoder_outputs = self.transformer( + inputs_embeds=hidden_states, + return_all_hidden_states=select_layers is not None, + ) + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 67258c2f77b8..ac934abea45d 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -26,7 +26,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -76,7 +76,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ) -> None: @@ -108,7 +108,7 @@ class DeepseekMoE(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -203,10 +203,10 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -282,8 +282,8 @@ class DeepseekDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -328,7 +328,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual 
is None: @@ -382,9 +382,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -489,9 +489,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -500,7 +500,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index faa7edd4bc3c..107b1e1a0582 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -224,7 +223,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if inputs_embeds is not None: raise NotImplementedError( @@ -235,7 +234,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 041dd6db7325..aa176ef05fcc 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn from transformers import PretrainedConfig +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -16,10 +16,17 @@ ParallelLMHead, VocabParallelEmbedding, ) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name +from .deepseek_v2 import ( + DeepseekV2DecoderLayer, + get_spec_layer_idx_from_weight_name, +) from .interfaces import SupportsPP from .utils import maybe_prefix @@ -29,7 +36,7 @@ def __init__( self, config: PretrainedConfig, prefix: str, - 
quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -56,6 +63,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + self.device = current_platform.device_type + self.is_v32 = hasattr(config, "index_topk") if self.is_v32: topk_tokens = config.index_topk @@ -63,7 +72,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, dtype=torch.int32, - device="cuda", + device=self.device, ) else: topk_indices_buffer = None @@ -83,7 +92,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None @@ -135,7 +144,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: if inputs_embeds is None: @@ -162,6 +171,7 @@ def compute_logits( return logits +@support_torch_compile class DeepSeekMTP(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -178,8 +188,8 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: hidden_states = self.model( @@ -191,7 +201,7 @@ def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.model.compute_logits(hidden_states, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -271,6 +281,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if name.endswith(".bias") and name not in params_dict: continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + # According to DeepSeek-V3 Technical Report, MTP modules # shares embedding layer. We only load the first weights. 
if ( diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py new file mode 100644 index 000000000000..fa24db456af4 --- /dev/null +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -0,0 +1,597 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Deepseek-OCR model compatible with HuggingFace weights.""" + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal + +import torch +import torch.nn as nn +from transformers import BatchFeature, CLIPVisionConfig + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargs, + NestedTensors, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config +from vllm.transformers_utils.processors.deepseek_ocr import ( + BASE_SIZE, + CROP_MODE, + IMAGE_SIZE, + DeepseekOCRProcessor, + count_tiles, +) +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +from .deepencoder import DeepCLIPVisionTransformer, build_sam_vit_b +from .deepseek_vl2 import MlpProjector + +# The image token id may be various +_IMAGE_TOKEN = "<image>" + + +class DeepseekOCRImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of images + - p: Number of patches + - base_size: Base size of the processor + - image_size: Image size of the processor + """ + + type: Literal["pixel_values"] + data: Annotated[ + torch.Tensor, + TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}), + ] + images_crop: Annotated[ + torch.Tensor, + TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}), + ] + images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)] + + +class NoRepeatNGramLogitsProcessor: + def __init__( + self, + ngram_size: int, + window_size: int, + whitelist_token_ids: set[int] | None = None, + ): + self.ngram_size = ngram_size + self.window_size = window_size + self.whitelist_token_ids = whitelist_token_ids or set() + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + if len(output_ids) < self.ngram_size: + return logits + + current_prefix = tuple(output_ids[-(self.ngram_size - 1) :]) + + search_start = max(0, len(output_ids) - self.window_size) + search_end = len(output_ids) - self.ngram_size + 1 + + banned_tokens = set() + for i in range(search_start, search_end): + ngram = tuple(output_ids[i : i + self.ngram_size]) + if ngram[:-1] == current_prefix: + 
banned_tokens.add(ngram[-1]) + + banned_tokens = banned_tokens - self.whitelist_token_ids + + if banned_tokens: + logits[list(banned_tokens)] = -float("inf") + + return logits + + +class NGramPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of overriding the wrapper class `__init__()` in order to utilize + info about the device type""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + super().__init__(vllm_config, device, is_pin_memory) + + def is_argmax_invariant(self) -> bool: + return True + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> RequestLogitsProcessor | None: + ngram_size = params.extra_args and params.extra_args.get("ngram_size") + window_size = params.extra_args and params.extra_args.get("window_size", 100) + whitelist_token_ids = params.extra_args and params.extra_args.get( + "whitelist_token_ids", None + ) + if ngram_size is None: + return None + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError( + f"`ngram_size` has to be a strictly positive integer, got {ngram_size}." + ) + if not isinstance(window_size, int) or window_size <= 0: + raise ValueError( + "`window_size` has to be a strictly positive integer, " + f"got {window_size}." + ) + if whitelist_token_ids is not None and not isinstance( + whitelist_token_ids, Iterable + ): + raise ValueError( + "`whitelist_token_ids` has to be a set of integers, " + f"got {whitelist_token_ids}." + ) + else: + whitelist_token_ids = ( + set(whitelist_token_ids) if whitelist_token_ids else None + ) + return NoRepeatNGramLogitsProcessor( + ngram_size=ngram_size, + window_size=window_size, + whitelist_token_ids=whitelist_token_ids, + ) + + +class DeepseekOCRProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(DeepseekVLV2Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(DeepseekOCRProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def get_num_image_tokens( + self, *, image_width: int, image_height: int, cropping: bool = True + ) -> int: + image_size = IMAGE_SIZE + base_size = BASE_SIZE + patch_size = 16 + downsample_ratio = 4 + + if CROP_MODE: + if image_width <= 640 and image_height <= 640: + crop_ratio = [1, 1] + else: + # find the closest aspect ratio to the target + crop_ratio = count_tiles( + image_width, image_height, image_size=IMAGE_SIZE + ) + + num_width_tiles, num_height_tiles = crop_ratio + else: + num_width_tiles = num_height_tiles = 1 + + h = w = math.ceil((base_size // patch_size) / downsample_ratio) + + h2 = w2 = math.ceil((image_size // patch_size) / downsample_ratio) + + global_views_tokens = h * (w + 1) + if num_width_tiles > 1 or num_height_tiles > 1: + local_views_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1) + else: + local_views_tokens = 0 + + return global_views_tokens + local_views_tokens + 1 + + def get_image_size_with_most_features(self) -> ImageSize: + if IMAGE_SIZE == 1024 and BASE_SIZE == 1280: + return ImageSize(width=1024 * 2, height=1024 * 2) + return ImageSize(width=640 * 2, height=640 * 2) + + +class DeepseekOCRDummyInputsBuilder(BaseDummyInputsBuilder[DeepseekOCRProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + 
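A usage sketch for `NoRepeatNGramLogitsProcessor` above (assumes this diff is applied). `NGramPerReqLogitsProcessor` builds one of these per request from `SamplingParams.extra_args`; called directly, it bans any token that would complete an n-gram already seen in the window, except whitelisted ids:

```python
import torch

from vllm.model_executor.models.deepseek_ocr import NoRepeatNGramLogitsProcessor

proc = NoRepeatNGramLogitsProcessor(ngram_size=3, window_size=100,
                                    whitelist_token_ids={0})

# Both trigrams (5, 7, 9) and (5, 7, 0) already occurred, and the current
# prefix is (5, 7), so tokens 9 and 0 are candidates for banning.
output_ids = [5, 7, 9, 5, 7, 0, 5, 7]
logits = proc(output_ids, torch.zeros(16))
print(logits[9].item())                # -inf: 9 would repeat the trigram (5, 7, 9)
print(logits[0].item())                # 0.0: 0 is whitelisted, so it is never banned
```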
self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + max_image_size = self.info.get_image_size_with_most_features() + + return { + "image": self._get_dummy_images( + width=max_image_size.width, + height=max_image_size.height, + num_images=num_images, + ) + } + + +class DeepseekOCRMultiModalProcessor( + BaseMultiModalProcessor[DeepseekOCRProcessingInfo] +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(prompt=prompt, **mm_data), + mm_kwargs, + ) + + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer( + prompt, add_special_tokens=True, return_tensors="pt" + ) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2))) + is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1) + patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0) + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + images_spatial_crop=MultiModalFieldConfig.batched("image"), + images_crop=MultiModalFieldConfig.flat_from_sizes( + "image", patches_per_image + ), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + image_token_id = hf_processor.image_token_id + assert isinstance(image_token_id, int) + + def get_replacement_deepseek_vl2(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=size.width, + image_height=size.height, + cropping=CROP_MODE, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_deepseek_vl2, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + DeepseekOCRMultiModalProcessor, + info=DeepseekOCRProcessingInfo, + dummy_inputs=DeepseekOCRDummyInputsBuilder, +) +class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # map prefix for language backbone + "model.embed_tokens.": "language_model.model.embed_tokens.", + "model.layers.": "language_model.model.layers.", + "model.norm.": "language_model.model.norm.", + "lm_head.": "language_model.lm_head.", + # remove "model." 
prefix for other components + "model.": "", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return "<image>" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: DeepseekVLV2Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_config = config.vision_config + self.projector_config = config.projector_config + self.text_config = config.text_config + + model_config = vllm_config.model_config + tokenizer = cached_tokenizer_from_config(model_config) + self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + + self.sam_model = build_sam_vit_b() + clip_vision_config = CLIPVisionConfig( + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + num_hidden_layers=24, + image_size=224, + patch_size=14, + projection_dim=512, + layer_norm_eps=1e-5, + ) + self.vision_model = DeepCLIPVisionTransformer( + config=clip_vision_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.projector = MlpProjector(self.projector_config) + self.tile_tag = config.tile_tag + self.global_view_pos = config.global_view_pos + + # special token for image token sequence format + n_embed = self.projector_config.n_embed + embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) + if self.tile_tag == "2D": + # <|view_separator|>, <|\n|> + self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std) + # This is a typo in original implementation + self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) + else: + raise ValueError( + f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" + ) + + if self.text_config.topk_method == "noaux_tc": + architectures = ["DeepseekV3ForCausalLM"] + elif not self.text_config.use_mla: + architectures = ["DeepseekForCausalLM"] + else: + architectures = ["DeepseekV2ForCausalLM"] + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=architectures, + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> DeepseekOCRImagePixelInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + images_spatial_crop = kwargs.pop("images_spatial_crop", None) + images_crop = kwargs.pop("images_crop", None) + + if pixel_values is None or torch.sum(pixel_values).item() == 0: + return None + + if pixel_values is not None: + base_size = self.vision_config.image_size + return DeepseekOCRImagePixelInputs( + type="pixel_values", + data=pixel_values, + images_crop=images_crop, + images_spatial_crop=images_spatial_crop, + resolve_bindings={ + "base_size": base_size, + }, + ) + + raise AssertionError("This line should be unreachable.") + + def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor: + global_features_1 = self.sam_model(image_tensor) + global_features_2 = self.vision_model(image_tensor, global_features_1) + features = torch.cat( + ( + global_features_2[:, 1:], + global_features_1.flatten(2).permute(0, 2, 1), + ), + dim=-1, + ) + features = self.projector(features) + + _, hw, dim = 
features.shape + side = int(hw**0.5) + + features = features.view(side, side, dim) + newline = self.image_newline[None, None, :].expand(side, 1, dim) + features = torch.cat([features, newline], dim=1) + return features.view(-1, dim) + + def _encode_local_features( + self, patches: torch.Tensor, crop_shape: torch.Tensor + ) -> torch.Tensor | None: + if torch.sum(patches).item() == 0: + return None + + local_features_1 = self.sam_model(patches) + local_features_2 = self.vision_model(patches, local_features_1) + features = torch.cat( + ( + local_features_2[:, 1:], + local_features_1.flatten(2).permute(0, 2, 1), + ), + dim=-1, + ) + features = self.projector(features) + + _, hw, dim = features.shape + patch_side = int(hw**0.5) + + width_tiles = int(crop_shape[0].item()) + height_tiles = int(crop_shape[1].item()) + + features = ( + features.view(height_tiles, width_tiles, patch_side, patch_side, dim) + .permute(0, 2, 1, 3, 4) + .reshape(height_tiles * patch_side, width_tiles * patch_side, dim) + ) + newline = self.image_newline[None, None, :].expand( + height_tiles * patch_side, 1, dim + ) + features = torch.cat([features, newline], dim=1) + + return features.view(-1, dim) + + def _pixel_values_to_embedding( + self, + pixel_values: torch.Tensor, + images_crop: torch.Tensor, + images_spatial_crop: torch.Tensor, + ) -> NestedTensors: + images_in_this_batch = [] + + is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1) + patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0) + images_crop = images_crop.split(patches_per_image.tolist()) + for jdx in range(images_spatial_crop.size(0)): + patches = images_crop[jdx] + image_ori = pixel_values[[jdx]] + crop_shape = images_spatial_crop[jdx] + + global_features = self._encode_global_features(image_ori) + local_features = self._encode_local_features(patches, crop_shape) + + if local_features is not None: + combined = torch.cat( + [local_features, global_features, self.view_seperator[None, :]], + dim=0, + ) + else: + combined = torch.cat( + [global_features, self.view_seperator[None, :]], dim=0 + ) + + images_in_this_batch.append(combined) + + return images_in_this_batch + + def _process_image_input( + self, image_input: DeepseekOCRImagePixelInputs + ) -> torch.Tensor: + pixel_values = image_input.data + images_crop = image_input.images_crop + images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long) + + vision_features = self._pixel_values_to_embedding( + pixel_values=pixel_values, + images_crop=images_crop, + images_spatial_crop=images_spatial_crop, + ) + + return vision_features + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object + ) -> MultiModalEmbeddings | None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ): + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return 
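A token-count sketch connecting `_encode_global_features` / `_encode_local_features` above (each feature row gets an `image_newline`, and `images_crop` is split per image via `patches_per_image`) to `DeepseekOCRProcessingInfo.get_num_image_tokens`. It assumes the IMAGE_SIZE=1024 / BASE_SIZE=1280 configuration referenced in `get_image_size_with_most_features`; the 2x2 tiling is illustrative:

```python
import math

patch_size, downsample_ratio = 16, 4

h = w = math.ceil((1280 // patch_size) / downsample_ratio)    # 20: global feature grid
h2 = w2 = math.ceil((1024 // patch_size) / downsample_ratio)  # 16: per-tile feature grid

global_tokens = h * (w + 1)                  # 420: each of 20 rows gets an image_newline
num_width_tiles, num_height_tiles = 2, 2     # e.g. a large image cropped into 2x2 tiles
local_tokens = (num_height_tiles * h2) * (num_width_tiles * w2 + 1)   # 32 * 33 = 1056
total = global_tokens + local_tokens + 1     # +1 for the trailing view_seperator token
print(total)                                 # 1477 placeholder tokens for this image
```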
self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + return autoloaded_weights diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a2fb0cfe6000..db7b86ffaf96 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -27,7 +27,7 @@ import typing from collections.abc import Callable, Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -49,7 +49,11 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_fusion_shared_expert_enabled, + is_rocm_aiter_moe_enabled, +) from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -58,13 +62,12 @@ RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -76,8 +79,8 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import cdiv, direct_register_custom_op from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata, @@ -107,7 +110,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, is_sequence_parallel=False, prefix: str = "", @@ -151,9 +154,9 @@ def forward(self, x): class DeepseekV2MoE(nn.Module): def __init__( self, - config: Union[DeepseekV2Config, DeepseekV3Config], + config: DeepseekV2Config | DeepseekV3Config, parallel_config: ParallelConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -204,27 +207,10 @@ def __init__( self.physical_expert_start + self.n_local_physical_experts ) - if config.n_shared_experts is None: - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - 
use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - # we do scaling outside, set factor to 1.0 to avoid double mul - routed_scaling_factor=1.0, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - ) + if ( + config.n_shared_experts is None + or is_rocm_aiter_fusion_shared_expert_enabled() + ): self.shared_experts = None else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -239,27 +225,34 @@ def __init__( prefix=f"{prefix}.shared_experts", ) - self.experts = SharedFusedMoE( - shared_experts=self.shared_experts, - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - # we do scaling outside, set factor to 1.0 to avoid double mul - routed_scaling_factor=1.0, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - ) + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + gate=self.gate, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not is_rocm_aiter_moe_enabled() + else self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + n_shared_experts=config.n_shared_experts + if is_rocm_aiter_fusion_shared_expert_enabled() + else None, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -272,23 +265,27 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - fused_moe_out = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) - - if self.shared_experts is not None: - shared_output, final_hidden_states = fused_moe_out + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) else: - shared_output = None - final_hidden_states = fused_moe_out + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + shared_output, final_hidden_states 
= fused_moe_out + if self.shared_experts is None: + assert shared_output is None # Fix FP16 overflow # See DeepseekV2DecoderLayer for more details. if hidden_states.dtype != torch.float16: - final_hidden_states *= self.routed_scaling_factor + if not is_rocm_aiter_moe_enabled(): + final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None shared_output *= 1.0 / self.routed_scaling_factor @@ -322,7 +319,7 @@ class DeepseekV2Attention(nn.Module): def __init__( self, vllm_config: VllmConfig, - config: Union[DeepseekV2Config, DeepseekV3Config], + config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -331,11 +328,11 @@ def __init__( q_lora_rank: int, kv_lora_rank: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - topk_indices_buffer: Optional[torch.Tensor] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + topk_indices_buffer: torch.Tensor | None = None, prefix: str = "", ) -> None: super().__init__() @@ -490,7 +487,7 @@ def __init__( raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - def get_kv_cache_spec(self) -> KVCacheSpec: + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: return MLAAttentionSpec( # Only has one vector instead of K + V block_size=self.cache_config.block_size, num_kv_heads=1, @@ -504,69 +501,6 @@ def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend -@torch.inference_mode() -def cp_gather_indexer_k_quant_cache( - kv_cache, # [num_blocks, block_size, head_dim + 1] - dst_value, # [cu_seq_lens[-1], head_dim] - dst_scale, # [cu_seq_lens[-1], 4] - block_table, # [batch_size, num_blocks] - cu_seq_lens, # [batch_size + 1, ] - batch_size, -): - num_blocks, block_size, _ = kv_cache.shape - head_dim = dst_value.shape[-1] - kv_cache = kv_cache.view(num_blocks, -1) - - expected_value = [] - expected_scale = [] - for b in range(batch_size): - s = cu_seq_lens[b + 1] - cu_seq_lens[b] - if s == 0: - continue - tot = cdiv(s, block_size) - blocks = block_table[b, :tot] - - value = [] - scale = [] - full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32) - non_remaining_value = kv_cache[ - blocks[full_block], : block_size * head_dim - ].view(-1, head_dim) - non_remaining_scale = kv_cache[ - blocks[full_block], block_size * head_dim : - ].view(-1, 4) - - remaining = s - (tot - 1) * block_size - - value = torch.cat( - [ - non_remaining_value, - kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim), - ], - dim=0, - ) - scale = torch.cat( - [ - non_remaining_scale, - kv_cache[ - blocks[-1], - block_size * head_dim : block_size * head_dim + remaining * 4, - ].view(-1, 4), - ], - dim=0, - ) - - expected_value.append(value) - expected_scale.append(scale) - - gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) - gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) - gather_value = gather_value.view(torch.float8_e4m3fn) - gather_scale = gather_scale.view(torch.float32) - dst_value.copy_(gather_value) - dst_scale.copy_(gather_scale) - - def sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: str, @@ -575,12 +509,12 @@ def sparse_attn_indexer( k: torch.Tensor, 
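For context on the SharedFusedMoE migration above: callers now receive a (shared_output, routed_output) pair from the fused layer, and when the router runs inside the layer the hidden states are passed in the router_logits slot. The class below is a toy stand-in written only to illustrate that calling convention; it is not the vLLM layer, and its internals (one linear per expert group, a max-weight "routing") are invented.

# Toy stand-in (not the real vLLM layer) for the calling convention above.
import torch
import torch.nn as nn


class ToySharedFusedMoE(nn.Module):
    def __init__(self, hidden: int, is_internal_router: bool = False):
        super().__init__()
        self.is_internal_router = is_internal_router
        self.gate = nn.Linear(hidden, 8, bias=False)          # toy router
        self.shared = nn.Linear(hidden, hidden, bias=False)   # toy shared expert
        self.routed = nn.Linear(hidden, hidden, bias=False)   # toy routed experts

    def forward(self, hidden_states, router_logits):
        if self.is_internal_router:
            # In this mode the caller passes hidden_states in both slots.
            router_logits = self.gate(router_logits)
        weights = router_logits.softmax(dim=-1).amax(dim=-1, keepdim=True)
        return self.shared(hidden_states), weights * self.routed(hidden_states)


x = torch.randn(4, 16)
moe = ToySharedFusedMoE(16, is_internal_router=True)
shared_out, routed_out = moe(hidden_states=x, router_logits=x)
out = routed_out if shared_out is None else shared_out + routed_out
print(out.shape)  # torch.Size([4, 16])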
weights: torch.Tensor, quant_block_size: int, - scale_fmt: Optional[str], + scale_fmt: str | None, topk_tokens: int, head_dim: int, max_model_len: int, total_seq_lens: int, - topk_indices_buffer: Optional[torch.Tensor], + topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata @@ -626,44 +560,38 @@ def sparse_attn_indexer( dtype=torch.float8_e4m3fn, ) k_scale = torch.empty( - [chunk.total_seq_lens, 1], device=k.device, dtype=torch.float32 + [chunk.total_seq_lens, 4], + device=k.device, + dtype=torch.uint8, ) - cp_gather_indexer_k_quant_cache( + ops.cp_gather_indexer_k_quant_cache( kv_cache, k_fp8, k_scale, chunk.block_table, chunk.cu_seq_lens, - chunk.num_reqs, ) logits = fp8_mqa_logits( q_fp8[chunk.token_start : chunk.token_end], - (k_fp8, k_scale), + (k_fp8, k_scale.view(torch.float32)), weights[chunk.token_start : chunk.token_end], chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) num_rows = logits.shape[0] assert topk_tokens == 2048, "top_k_per_row assumes size 2048" - topk_indices = torch.empty( - num_rows, topk_tokens, dtype=torch.int32, device=logits.device - ) - topk_values = torch.empty( - num_rows, topk_tokens, dtype=logits.dtype, device=logits.device - ) + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] torch.ops._C.top_k_per_row( logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, topk_indices, - topk_values, num_rows, logits.stride(0), logits.stride(1), ) - topk_indices_buffer[ - chunk.token_start : chunk.token_end, : topk_indices.shape[-1] - ] = topk_indices.to(dtype=torch.int32) if has_decode: decode_metadata = attn_metadata.decode @@ -697,31 +625,15 @@ def sparse_attn_indexer( decode_metadata.schedule_metadata, max_model_len=max_model_len, ) - # padded query len - current_device = padded_q_fp8_decode_tokens.device - padded_num_tokens = batch_size * next_n - row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n - next_n_offset = ( - torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device) - % next_n - ) - index_end_pos = ( - decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1 - ).unsqueeze(1) num_rows = logits.shape[0] assert topk_tokens == 2048, "top_k_per_row assumes size 2048" - topk_indices = torch.empty( - num_rows, topk_tokens, dtype=torch.int32, device=logits.device - ) - topk_values = torch.empty( - num_rows, topk_tokens, dtype=logits.dtype, device=logits.device - ) - torch.ops._C.top_k_per_row( + topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + + torch.ops._C.top_k_per_row_decode( logits, - torch.zeros(num_rows, dtype=torch.int32, device=logits.device), - index_end_pos.to(dtype=torch.int32, device=logits.device), + next_n, + decode_metadata.seq_lens, topk_indices, - topk_values, num_rows, logits.stride(0), logits.stride(1), @@ -733,9 +645,9 @@ def sparse_attn_indexer( topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens, ) - topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( - topk_indices.to(dtype=torch.int32) - ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) return topk_indices_buffer @@ -748,12 +660,12 @@ def sparse_attn_indexer_fake( k: torch.Tensor, weights: torch.Tensor, quant_block_size: int, - scale_fmt: Optional[str], + scale_fmt: str | None, topk_tokens: int, head_dim: int, max_model_len: int, total_seq_lens: int, - topk_indices_buffer: Optional[torch.Tensor], + 
topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: # profile run # NOTE(Chen): create the max possible flattened_kv. So that @@ -779,12 +691,12 @@ class Indexer(nn.Module): def __init__( self, vllm_config: VllmConfig, - config: Union[DeepseekV2Config, DeepseekV3Config], + config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, q_lora_rank: int, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], - topk_indices_buffer: Optional[torch.Tensor], + quant_config: QuantizationConfig | None, + cache_config: CacheConfig | None, + topk_indices_buffer: torch.Tensor | None, prefix: str = "", ): super().__init__() @@ -901,21 +813,21 @@ class DeepseekV2MLAAttention(nn.Module): def __init__( self, vllm_config: VllmConfig, - config: Union[DeepseekV2Config, DeepseekV3Config], + config: DeepseekV2Config | DeepseekV3Config, hidden_size: int, num_heads: int, qk_nope_head_dim: int, qk_rope_head_dim: int, v_head_dim: int, - q_lora_rank: Optional[int], + q_lora_rank: int | None, kv_lora_rank: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", - topk_indices_buffer: Optional[torch.Tensor] = None, + topk_indices_buffer: torch.Tensor | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -1038,7 +950,7 @@ def __init__( topk_indices_buffer=topk_indices_buffer, ) - self.mla_attn = MultiHeadLatentAttention( + self.mla_attn = MultiHeadLatentAttentionWrapper( self.hidden_size, self.num_local_heads, self.scaling, @@ -1066,8 +978,8 @@ def __init__( self, vllm_config: VllmConfig, prefix: str, - config: Optional[DeepseekV2Config] = None, - topk_indices_buffer: Optional[torch.Tensor] = None, + config: DeepseekV2Config | None = None, + topk_indices_buffer: torch.Tensor | None = None, ) -> None: super().__init__() @@ -1138,7 +1050,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual is None: @@ -1186,6 +1098,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config + self.device = current_platform.device_type self.vocab_size = config.vocab_size self.is_v32 = hasattr(config, "index_topk") @@ -1195,7 +1108,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, dtype=torch.int32, - device="cuda", + device=self.device, ) else: topk_indices_buffer = None @@ -1233,9 +1146,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -1306,7 +1219,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups 
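The indexer changes above replace per-chunk index/value allocations with writes into a preallocated topk_indices_buffer sized for the maximum number of batched tokens and placed on the current platform's device. Below is a minimal sketch of that buffer-reuse pattern, using plain torch.topk as a stand-in for the fused top_k_per_row kernel; all sizes are toy values.

# Minimal sketch (plain torch, no custom ops) of the buffer-reuse pattern:
# one int32 buffer is allocated up front, and each chunk writes its top-k
# indices into a slice of it instead of materializing new tensors per chunk.
import torch

max_num_batched_tokens, topk_tokens, seq_len = 32, 8, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

topk_indices_buffer = torch.full(
    (max_num_batched_tokens, topk_tokens), -1, dtype=torch.int32, device=device
)

logits = torch.randn(10, seq_len, device=device)   # toy chunk of 10 tokens
token_start, token_end = 0, logits.shape[0]

# Stand-in for the fused kernel: write straight into the buffer slice.
topk_indices_buffer[token_start:token_end, :topk_tokens] = torch.topk(
    logits, k=topk_tokens, dim=-1
).indices.to(torch.int32)

print(topk_indices_buffer[:2])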
= config.n_group - self.moe_layers: list[FusedMoE] = [] + self.moe_layers: list[SharedFusedMoE] = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -1368,9 +1281,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -1379,10 +1292,21 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=0, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -1394,11 +1318,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts, + num_experts=self.config.n_routed_experts + + ( + self.config.n_shared_experts + if is_rocm_aiter_fusion_shared_expert_enabled() + else 0 + ), num_redundant_experts=self.num_redundant_experts, ) @@ -1412,6 +1341,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if spec_layer is not None: continue # skip spec decode layers for main model + is_fuse_shared_experts_layer = ( + is_rocm_aiter_fusion_shared_expert_enabled() + and ("mlp.shared_experts" in name) + ) + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: @@ -1424,6 +1358,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # for mlp.experts[0].gate_gate_up_proj, which breaks load. if ("mlp.experts." 
in name) and name not in params_dict: continue + if is_fuse_shared_experts_layer: + continue name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal @@ -1448,65 +1384,115 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: break else: is_expert_weight = False - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - - # Anyway, this is an expert weight and should not be - # attempted to load as other weights later - is_expert_weight = True - - # Do not modify `name` since the loop may continue here - # Instead, create a new variable - name_mapped = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name_mapped, self): - continue - - param = params_dict[name_mapped] - # We should ask the weight loader to return success or not - # here since otherwise we may skip experts with other - # available replicas. - weight_loader = typing.cast( - Callable[..., bool], param.weight_loader - ) - success = weight_loader( - param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True, - ) - if success: - name = name_mapped - break - else: - if is_expert_weight: - # We've checked that this is an expert weight - # However it's not mapped locally to this rank - # So we simply skip it - continue - - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader + + # Special handling: when AITER fusion_shared_experts is enabled, + # checkpoints may provide a single widened shared_experts tensor + # without explicit expert indices + # (e.g. ...mlp.shared_experts.gate_proj.weight). + # For models with multiple shared experts, split that tensor + # evenly into per-shared-expert slices and load them into + # appended expert slots mlp.experts.{n_routed_experts + j}.* + # accordingly. + num_chunks = 1 + if is_fuse_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 + # Determine split axis based on op type + # gate/up: ColumnParallel → split along dim 0 + # down: RowParallel → split along dim 1 + split_dim = 1 if "down_proj.weight" in name else 0 + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} " + f"not divisible by num_chunks {num_chunks}" ) - weight_loader(param, loaded_weight) - loaded_params.add(name) + chunk_size = total // num_chunks + + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight + + if is_fuse_shared_experts_layer: + if split_dim == 0: + weight_to_load = loaded_weight[ + j * chunk_size : (j + 1) * chunk_size, : + ] + else: + weight_to_load = loaded_weight[ + :, j * chunk_size : (j + 1) * chunk_size + ] + # Synthesize an expert-style name so expert mapping + # can route it + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) + + # Use expert_params_mapping to locate the destination + # param and delegate to its expert-aware weight_loader + # with expert_id. 
+ for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = chunk_name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fuse_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + if not is_fuse_shared_experts_layer: + loaded_params.add(name) return loaded_params @@ -1518,8 +1504,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): # Compatibility with # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py def get_spec_layer_idx_from_weight_name( - config: Union[DeepseekV2Config, DeepseekV3Config], weight_name: str -) -> Optional[int]: + config: DeepseekV2Config | DeepseekV3Config, weight_name: str +) -> int | None: if ( hasattr(config, "num_nextn_predict_layers") and config.num_nextn_predict_layers > 0 diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 8226e88c47a2..ea10245a84ee 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -6,7 +6,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal, TypeAlias import torch import torch.nn as nn @@ -18,8 +18,7 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.models.transformers import replace_linear_class +from vllm.model_executor.models.transformers.utils import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, @@ -49,8 +48,9 @@ ) from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from vllm.utils import is_list_of +from vllm.utils.collection_utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils 
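The shared-expert fusion path above splits one widened checkpoint tensor into per-shared-expert slices and renames them into appended expert slots. A toy illustration of that split-and-rename step follows; the checkpoint name, shapes, and expert counts are made up for the example.

# Toy sketch of the shared-expert split described in the comments above.
import torch

n_routed_experts, n_shared_experts = 64, 2
name = "model.layers.3.mlp.shared_experts.gate_proj.weight"   # hypothetical name
loaded_weight = torch.randn(2 * 128, 512)                     # widened gate_proj

split_dim = 1 if "down_proj.weight" in name else 0            # gate/up: dim 0
total = loaded_weight.shape[split_dim]
assert total % n_shared_experts == 0
chunk_size = total // n_shared_experts

for j in range(n_shared_experts):
    weight_to_load = loaded_weight.narrow(split_dim, j * chunk_size, chunk_size)
    chunk_name = name.replace(
        "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
    )
    print(chunk_name, tuple(weight_to_load.shape))
# model.layers.3.mlp.experts.64.gate_proj.weight (128, 512)
# model.layers.3.mlp.experts.65.gate_proj.weight (128, 512)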
import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import ( @@ -88,14 +88,12 @@ class DeepseekVL2VImageEmbeddingInputs(TensorSchema): """ type: Literal["image_embeds"] - data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], TensorShape("bn", "f", "h") - ] + data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", "h")] -DeepseekVL2ImageInputs = Union[ - DeepseekVL2ImagePixelInputs, DeepseekVL2VImageEmbeddingInputs -] +DeepseekVL2ImageInputs: TypeAlias = ( + DeepseekVL2ImagePixelInputs | DeepseekVL2VImageEmbeddingInputs +) class MlpProjector(nn.Module): @@ -103,9 +101,10 @@ def __init__(self, cfg: MlpProjectorConfig): super().__init__() self.cfg = cfg + self.projector_type = cfg.projector_type assert not cfg.token_pooling, "Token pooling is not supported currently." - if cfg.projector_type == "downsample_mlp_gelu": + if self.projector_type == "downsample_mlp_gelu": mlp_depth = cfg.depth mlp_ratio = cfg.mlp_ratio modules = [ @@ -122,7 +121,8 @@ def __init__(self, cfg: MlpProjectorConfig): modules.append(nn.GELU()) modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) modules = nn.Sequential(*modules) - + elif self.projector_type == "linear": + modules = nn.Linear(cfg.input_dim, cfg.n_embed) else: raise NotImplementedError( f"Unsupported projector type: {cfg.projector_type}" @@ -132,24 +132,25 @@ def __init__(self, cfg: MlpProjectorConfig): def forward(self, x): bs, hw, input_dim = x.shape - h = w = int((hw) ** 0.5) - """compute padding""" - if h % self.cfg.downsample_ratio: - pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio - else: - pad = 0 - x = x.reshape(bs, h, w, input_dim) - if pad > 0: - x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) - """4 to 1 concat""" - x = x.permute(0, 3, 1, 2) # B, C, H, W - x = F.unfold( - x, - kernel_size=self.cfg.downsample_ratio, - stride=self.cfg.downsample_ratio, - padding=0, - ) # B, C*4, HW // 4 - x = x.permute(0, 2, 1) + if self.projector_type == "downsample_mlp_gelu": + h = w = int((hw) ** 0.5) + """compute padding""" + if h % self.cfg.downsample_ratio: + pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio + else: + pad = 0 + x = x.reshape(bs, h, w, input_dim) + if pad > 0: + x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) + """4 to 1 concat""" + x = x.permute(0, 3, 1, 2) # B, C, H, W + x = F.unfold( + x, + kernel_size=self.cfg.downsample_ratio, + stride=self.cfg.downsample_ratio, + padding=0, + ) # B, C*4, HW // 4 + x = x.permute(0, 2, 1) return self.layers(x) @@ -161,7 +162,7 @@ def get_hf_config(self): def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -214,7 +215,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -310,11 +311,11 @@ def get_replacement_deepseek_vl2(item_idx: int): def _cached_apply_hf_processor( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: 
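Related to the MlpProjector changes above: the retained downsample_mlp_gelu path pads the square token grid to a multiple of the downsample ratio and then uses F.unfold to merge each ratio x ratio block of vision tokens into a single token. A small self-contained sketch with toy sizes (the 5x5 grid, channel count, and ratio are invented):

# Toy sketch of the unfold-based token merge kept by the projector.
import torch
import torch.nn.functional as F

bs, hw, c, ratio = 1, 25, 8, 2          # 5x5 grid of 8-dim tokens, merge 2x2
x = torch.randn(bs, hw, c)

h = w = int(hw ** 0.5)
pad = (ratio - h % ratio) % ratio       # pad 5 -> 6 so the grid divides evenly
x = x.reshape(bs, h, w, c)
x = F.pad(x, (0, 0, 0, pad, 0, pad))    # pad width then height
x = x.permute(0, 3, 1, 2)               # B, C, H, W
x = F.unfold(x, kernel_size=ratio, stride=ratio)   # B, C*ratio^2, (H/r)*(W/r)
x = x.permute(0, 2, 1)

print(x.shape)  # torch.Size([1, 9, 32]) -- 3x3 merged tokens of dim 8*4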
MultiModalUUIDDict | None = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is @@ -353,7 +354,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -454,7 +455,7 @@ def patch_vit_for_tp(self, vit: torch.nn.Module, quant_config: QuantizationConfi def _init_vision_module( self, vision_config: VisionEncoderConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", ) -> nn.Module: # TODO: refactor vision model through timm wrapper from transformers @@ -480,7 +481,7 @@ def _init_vision_module( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[DeepseekVL2ImageInputs]: + ) -> DeepseekVL2ImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) images_spatial_crop = kwargs.pop("images_spatial_crop", None) image_embeds = kwargs.pop("image_embeds", None) @@ -637,8 +638,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): if intermediate_tensors is not None: @@ -653,7 +654,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 1ae7457fb215..c33cb3d84478 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -27,7 +27,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -42,7 +42,7 @@ tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -80,7 +80,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ) -> None: @@ -117,7 +117,7 @@ class Dots1MoE(nn.Module): def __init__( self, config: Dots1Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -145,7 +145,21 @@ def __init__( else: self.gate.e_score_correction_bias = None - self.experts = FusedMoE( + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = Dots1MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + 
prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -163,29 +177,19 @@ def __init__( e_score_correction_bias=self.gate.e_score_correction_bias, ) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = Dots1MLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - prefix=f"{prefix}.shared_experts", - ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states) final_hidden_states = ( self.experts(hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor ) - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + + if self.shared_experts is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) @@ -199,10 +203,10 @@ def __init__( num_kv_heads: int, config: Dots1Config, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -285,8 +289,8 @@ def __init__( config: Dots1Config, prefix: str, model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -334,7 +338,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: if residual is None: residual = hidden_states @@ -399,9 +403,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -426,7 +430,7 @@ def forward( return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -542,9 +546,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> 
Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, @@ -556,7 +560,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 1bc50f27269e..6d462ad8ae62 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal, TypeAlias import torch import torch.nn as nn @@ -92,7 +92,7 @@ class DotsOCRImageEmbeddingInputs(TensorSchema): image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] -DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, DotsOCRImageEmbeddingInputs] +DotsOCRImageInputs: TypeAlias = DotsOCRImagePixelInputs | DotsOCRImageEmbeddingInputs class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder): @@ -104,7 +104,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -134,7 +134,7 @@ def get_hf_config(self) -> DotsOCRConfig: return config - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_mm_max_tokens_per_item( @@ -253,9 +253,10 @@ def __init__( num_heads: int = 16, bias: bool = True, *, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() @@ -288,7 +289,9 @@ def __init__( ) # Select attention backend self.attn_backend = get_vit_attn_backend( - self.hidden_size_per_attention_head, torch.get_default_dtype() + self.hidden_size_per_attention_head, + torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) self.use_upstream_fa = False @@ -296,6 +299,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) if self.attn_backend not in { @@ -316,10 +320,10 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, + rotary_pos_emb: torch.Tensor | None = None, *, - max_seqlen: Optional[int] = None, - seqlens: Optional[list[int]] = None, + max_seqlen: int | None = None, + seqlens: list[int] | None = None, ) -> torch.Tensor: # [S, C] -> [S, B=1, C] x = hidden_states.unsqueeze(1) @@ -394,7 +398,7 @@ def __init__( self, config, *, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -507,9 +511,10 @@ def __init__( self, config, *, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + 
attn_backend_override: _Backend | None = None, ): super().__init__() @@ -521,6 +526,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) self.mlp = DotsSwiGLUFFN( @@ -537,8 +543,8 @@ def forward( *, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, - seqlens: Optional[list[int]] = None, + max_seqlen: int | None = None, + seqlens: list[int] | None = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), @@ -555,12 +561,13 @@ class DotsVisionTransformer(nn.Module): def __init__( self, config: DotsVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() self.config = config @@ -571,7 +578,9 @@ def __init__( head_dim = config.embed_dim // config.num_attention_heads self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype() + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( torch.get_default_dtype() @@ -591,6 +600,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.blocks.{i}", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) for i in range(num_layers) ] @@ -653,7 +663,7 @@ def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor - ) -> tuple[Optional[int], Optional[list[int]]]: + ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( self.attn_backend == _Backend.FLASH_ATTN @@ -680,7 +690,7 @@ def forward( dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) for blk in self.blocks: @@ -734,7 +744,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA supports_encoder_tp_data = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|img|><|imgpad|><|endofimg|>" @@ -750,11 +760,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config.vision_config = vision_config else: vision_config = self.config.vision_config + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.vision_tower = DotsVisionTransformer( vision_config, quant_config=self.quant_config, prefix=maybe_prefix(prefix, "vision_tower"), use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, ) self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( vllm_config=vllm_config, @@ -765,7 +781,7 @@ def __init__(self, *, vllm_config: VllmConfig, 
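On the cu_seqlens change above (torch.cat with new_zeros instead of F.pad): both forms prepend a single zero so that cu_seqlens[i]:cu_seqlens[i+1] indexes the i-th sequence; the cat form simply builds the zero from the same tensor, keeping its dtype and device. A quick equivalence check with toy lengths:

# Both constructions produce the same prefix-sum layout.
import torch
import torch.nn.functional as F

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu = seq_lens.cumsum(dim=0, dtype=torch.int32)

padded = F.pad(cu, (1, 0), value=0)            # old form
catted = torch.cat([cu.new_zeros(1), cu])      # new form

assert torch.equal(padded, catted)
print(catted.tolist())  # [0, 5, 8, 15]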
prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[DotsOCRImageInputs]: + ) -> DotsOCRImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -834,10 +850,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: @@ -861,7 +877,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 3cb93177a383..607589e68ef3 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -23,9 +23,10 @@ # limitations under the License. """Inference-only ErineMoE model compatible with HuggingFace weights.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -33,11 +34,15 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -58,7 +63,7 @@ ) from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, @@ -79,7 +84,7 @@ def __init__( intermediate_size: int, hidden_act: str, use_bias: bool = False, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ) -> None: @@ -116,14 +121,36 @@ class Ernie4_5_MoeMoE(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + enable_eplb: bool = False, ): super().__init__() layer_idx = extract_layer_index(prefix) self.layer_idx = layer_idx self.tp_size = get_tensor_model_parallel_world_size() + + self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None) + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + 
self.n_routed_experts: int = config.moe_num_experts + self.n_shared_experts: int = self.moe_num_shared_experts + + # Load balancing settings. + vllm_config = get_current_vllm_config() + eplb_config = vllm_config.parallel_config.eplb_config + self.enable_eplb = enable_eplb + + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0 if self.tp_size > config.moe_num_experts: @@ -145,18 +172,6 @@ def __init__( torch.empty(config.moe_num_experts, dtype=torch.float32) ) - self.experts = FusedMoE( - num_experts=config.moe_num_experts, - top_k=config.moe_k, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=True, - quant_config=quant_config, - prefix=f"{prefix}.experts", - e_score_correction_bias=self.gate.e_score_correction_bias, - ) - if self.has_shared_experts: intermediate_size = ( config.moe_intermediate_size * config.moe_num_shared_experts @@ -167,16 +182,30 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.shared_experts", - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + reduce_results=False, ) + else: + self.shared_experts = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.moe_num_experts, + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None - if self.has_shared_experts: - shared_output = self.shared_experts(hidden_states) router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) @@ -184,8 +213,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states=hidden_states, router_logits=router_logits ) - if self.has_shared_experts and shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + if self.has_shared_experts: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( @@ -201,14 +230,14 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, - head_dim: Optional[int] = None, + head_dim: int | None = None, rope_theta: float = 500000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 131072, rms_norm_eps: float = 1e-05, qkv_bias: bool = False, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -295,9 +324,10 @@ class 
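The EPLB bookkeeping added to Ernie4_5_MoeMoE above reduces to simple arithmetic over the expert-parallel group: physical experts are the logical experts plus the redundant replicas, split evenly across ranks, with each rank owning a contiguous slice. A worked example with made-up counts:

# Arithmetic sketch of the expert-parallel bookkeeping, toy numbers only.
n_routed_experts = 64          # logical (routed) experts
n_redundant_experts = 8        # extra replicas managed by EPLB
ep_size, ep_rank = 4, 1        # expert-parallel world size and this rank

n_logical_experts = n_routed_experts
n_physical_experts = n_logical_experts + n_redundant_experts            # 72
n_local_physical_experts = n_physical_experts // ep_size                # 18
physical_expert_start = ep_rank * n_local_physical_experts              # 18
physical_expert_end = physical_expert_start + n_local_physical_experts  # 36

print(physical_expert_start, physical_expert_end)  # 18 36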
Ernie4_5_MoeDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + enable_eplb: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -338,7 +368,10 @@ def __init__( and layer_idx <= moe_layer_end_index ): self.mlp = Ernie4_5_MoeMoE( - config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, ) else: self.mlp = Ernie4_5_MoeMLP( @@ -359,7 +392,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual is None: @@ -393,6 +426,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config + parallel_config = vllm_config.parallel_config + eplb_config = parallel_config.eplb_config + enable_eplb = parallel_config.enable_eplb + + self.num_redundant_experts = eplb_config.num_redundant_experts if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -411,6 +449,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config=cache_config, quant_config=quant_config, prefix=prefix, + enable_eplb=enable_eplb, ), prefix=f"{prefix}.layers", ) @@ -431,9 +470,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -460,11 +499,12 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.moe_num_experts, + num_redundant_experts=self.num_redundant_experts, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -513,34 +553,54 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_loader(param, loaded_weight, shard_id) break else: + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + if is_pp_missing_parameter(name_mapped, self): continue # Skip loading extra bias for GPTQ models. 
if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: + name_mapped.endswith(".bias") or name_mapped.endswith("_bias") + ) and name_mapped not in params_dict: continue - param = params_dict[name] - - weight_loader = param.weight_loader - weight_loader( + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( param, loaded_weight, - name, + name_mapped, shard_id=shard_id, expert_id=expert_id, + return_success=True, ) - break + if success: + name = name_mapped + break else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. if ( name.endswith(".bias") or name.endswith("_bias") @@ -563,7 +623,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params -class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): +class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -605,6 +665,81 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model.make_empty_intermediate_tensors ) + self.expert_weights = [] + + # Set MoE hyperparameters + moe_layers_indices = [ + i + for i in range(config.num_hidden_layers) + if ( + i >= config.moe_layer_start_index + and i <= config.moe_layer_end_index + and (i + 1) % config.moe_layer_interval == 0 + ) + ] + self.num_moe_layers = len(moe_layers_indices) + self.num_expert_groups = 1 + + self.moe_layers: list[SharedFusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Ernie4_5_MoeDecoderLayer) + if isinstance(layer.mlp, Ernie4_5_MoeMoE): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + logger.warning("No Ernie4_5_MoeMoE layer found in model.layers.") + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + else: + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
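The loader rework above switches to an expert-aware weight_loader that reports success, so weights for experts hosted on other ranks are skipped instead of being treated as unmapped parameters. A toy sketch of that calling convention; the loader and its locality rule here are invented for illustration and are not the vLLM API.

# Toy sketch of the return_success pattern: the loader says whether the
# expert lives on this rank, and the caller skips it otherwise.
import torch


def toy_expert_weight_loader(param, loaded_weight, name, *, shard_id,
                             expert_id, return_success=False):
    is_local = expert_id < 32            # pretend experts 0..31 live here
    if is_local:
        param.data.copy_(loaded_weight)
    return is_local if return_success else None


param = torch.nn.Parameter(torch.zeros(4))
for expert_id in (3, 40):
    success = toy_expert_weight_loader(
        param, torch.ones(4), f"mlp.experts.{expert_id}.up_proj.weight",
        shard_id="w1", expert_id=expert_id, return_success=True,
    )
    print(expert_id, "loaded" if success else "skipped (not on this rank)")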
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, Ernie4_5_MoeMoE): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -612,9 +747,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -623,7 +758,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 493260cf73ef..86536b21c33f 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -23,17 +23,18 @@ # limitations under the License. """Inference-only Erine VL model compatible with HuggingFace weights.""" +import itertools import math -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from vllm.attention.backends.registry import _Backend from vllm.attention.layer import ( @@ -76,6 +77,7 @@ from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -160,8 +162,9 @@ def __init__( embed_dim: int, num_heads: int, projection_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. 
@@ -194,6 +197,7 @@ def __init__( self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) self.use_upstream_fa = False @@ -202,6 +206,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) @@ -252,8 +257,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -330,7 +335,7 @@ def __init__( in_features: int, hidden_features: int, act_layer: type[nn.Module] = QuickGELU, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -362,9 +367,10 @@ def __init__( num_heads: int, mlp_ratio: float, act_layer: type[nn.Module] = QuickGELU, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() @@ -380,6 +386,7 @@ def __init__( projection_size=dim, quant_config=quant_config, prefix=f"{prefix}.attn", + attn_backend_override=attn_backend_override, ) self.mlp = Ernie4_5_VisionMLP( @@ -395,8 +402,8 @@ def forward( hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), @@ -454,8 +461,9 @@ def __init__( self, vision_config, norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() patch_size = vision_config.patch_size @@ -491,6 +499,7 @@ def __init__( norm_layer=norm_layer, quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}", + attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] @@ -502,7 +511,9 @@ def __init__( self.ln = nn.LayerNorm(hidden_size, eps=1e-6) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype() + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( torch.get_default_dtype() @@ -551,7 +562,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor - ) -> tuple[Optional[int], Optional[list[int]]]: + ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( self.attn_backend == _Backend.FLASH_ATTN @@ -574,11 +585,12 @@ def forward( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] ).cumsum(dim=0, dtype=torch.int32) + zeros = cu_seqlens.new_zeros(1) if 
num_pad > 0: - cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) + cu_seqlens = torch.cat([zeros, cu_seqlens, zeros]) cu_seqlens[-1] = cu_seqlens[-2] + num_pad else: - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + cu_seqlens = torch.cat([zeros, cu_seqlens]) # add batch size if hidden_states.ndim == 2: @@ -656,15 +668,15 @@ class Ernie4_5_VLVideoPixelInputs(TensorSchema): # === Vision Processor === # -def round_by_factor(number: Union[int, float], factor: int) -> int: +def round_by_factor(number: int | float, factor: int) -> int: return round(number / factor) * factor -def ceil_by_factor(number: Union[int, float], factor: int) -> int: +def ceil_by_factor(number: int | float, factor: int) -> int: return math.ceil(number / factor) * factor -def floor_by_factor(number: Union[int, float], factor: int) -> int: +def floor_by_factor(number: int | float, factor: int) -> int: return math.floor(number / factor) * factor @@ -898,7 +910,7 @@ def get_hf_processor(self, **kwargs: object): def get_image_processor(self, **kwargs: object): return self.get_hf_processor(**kwargs).image_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": None} def get_mm_max_tokens_per_item( @@ -917,7 +929,7 @@ def _get_vision_info( image_height: int, num_frames: int = 1, do_resize: bool = True, - image_processor: Optional[Any], + image_processor: Any | None, ) -> tuple[ImageSize, int]: if image_processor is None: image_processor = self.get_image_processor() @@ -954,7 +966,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - image_processor: Optional[Any], + image_processor: Any | None, ) -> int: _, num_image_tokens = self._get_vision_info( image_width=image_width, @@ -969,7 +981,7 @@ def get_num_video_tokens( image_width: int, image_height: int, num_frames: int, - image_processor: Optional[Any], + image_processor: Any | None, ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, @@ -1086,7 +1098,7 @@ def _pixel_values_norm( pixel_values = ( rescale_factor * pixel_values.to(torch.float32) - image_mean_tensor ) / image_std_tensor - pixel_values = pixel_values.to(hf_config.torch_dtype) + pixel_values = pixel_values.to(hf_config.dtype) return pixel_values def _call_hf_processor( @@ -1234,7 +1246,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1270,7 +1282,7 @@ def get_dummy_mm_data( dummy_inputs=Ernie4_5_VLDummyInputsBuilder, ) class Ernie4_5_VLMoeForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): merge_by_field_config = True @@ -1307,7 +1319,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" if modality.startswith("video"): @@ -1324,11 +1336,17 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.config = config self.multimodal_config = multimodal_config + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if 
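Earlier in this hunk, the Ernie vision transformer builds its cu_seqlens prefix by concatenating an explicit zero tensor instead of calling F.pad. A toy check that the two forms agree; the sequence lengths below are made up:

import torch
import torch.nn.functional as F

# Toy cu_seqlens for three sequences of lengths 4, 2 and 3 (assumed values).
cu_seqlens = torch.tensor([4, 6, 9], dtype=torch.int32)
zeros = cu_seqlens.new_zeros(1)

# Prepend only (no padding tokens) and prepend-plus-append (with padding):
assert torch.equal(torch.cat([zeros, cu_seqlens]), F.pad(cu_seqlens, (1, 0), value=0))
assert torch.equal(torch.cat([zeros, cu_seqlens, zeros]), F.pad(cu_seqlens, (1, 1), value=0))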
multimodal_config is not None + else None + ) self.vision_model = Ernie4_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "vision_model"), + attn_backend_override=attn_backend_override, ) self.language_model = Ernie4_5_VLMoeForCausalLM( @@ -1353,7 +1371,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """compute logits""" return self.language_model.compute_logits(hidden_states) @@ -1387,12 +1405,156 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: else: self.visual_token_mask = None + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + context_len: int = 0, + seq_len: int | None = None, + second_per_grid_ts: list[float] | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for Ernie VL.""" + + image_token_id = hf_config.im_patch_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1] + ): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_conv_size, + w // spatial_conv_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_grid_thw[mm_data_idx][0], + video_grid_thw[mm_data_idx][1], + video_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t // temporal_conv_size, + h // 
spatial_conv_size, + w // spatial_conv_size, + ) + + for t_idx in range(llm_grid_t): + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return llm_positions, mrope_position_delta + def get_language_model(self) -> torch.nn.Module: return self.language_model def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Ernie4_5_VLImageInputs]: + ) -> Ernie4_5_VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1408,7 +1570,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[Ernie4_5_VLVideoInputs]: + ) -> Ernie4_5_VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1483,7 +1645,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def get_multimodal_embeddings( self, **kwargs: object - ) -> Optional[MultiModalEmbeddings]: + ) -> MultiModalEmbeddings | None: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return None @@ -1497,21 +1659,21 @@ def get_multimodal_embeddings( for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: @@ -1532,8 +1694,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): forward_kwargs = { diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 51f49b8587e6..d002d1838c8e 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ 
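The get_mrope_input_positions method above assigns each image placeholder a 3-row (t, h, w) position block offset by the running maximum, while text tokens share one increasing index across all three rows. A compressed sketch of the image-block construction; the grid sizes and start offset are made-up values:

import torch

# One temporal step, a 4x4 patch grid, spatial_conv_size = 2 (all assumed).
st_idx = 5  # next free position after the preceding text segment
llm_grid_t, llm_grid_h, llm_grid_w = 1, 4 // 2, 4 // 2

t_index = (
    torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
)
h_index = (
    torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
)
w_index = (
    torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
)
positions = torch.stack([t_index, h_index, w_index]) + st_idx
assert positions.shape == (3, llm_grid_t * llm_grid_h * llm_grid_w)

# Text tokens that follow continue from the block's maximum plus one.
text_positions = torch.arange(3).view(1, -1).expand(3, -1) + positions.max() + 1
assert int(text_positions[0, 0]) == 7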
b/vllm/model_executor/models/ernie45_vl_moe.py @@ -25,7 +25,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -37,7 +37,7 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, @@ -74,7 +74,15 @@ class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP): - pass + def __init__(self, shared_experts: torch.nn.Module | None = None, **kwargs): + super().__init__(**kwargs) + self.shared_experts = shared_experts + + def forward(self, x): + if self.shared_experts is not None: + return self.shared_experts(x) + super().forward(x) + else: + return super().forward(x) class Ernie4_5_VLMoeAttention(nn.Module): @@ -83,15 +91,15 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, - head_dim: Optional[int] = None, + head_dim: int | None = None, rope_theta: float = 500000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, freq_allocation: int = 20, max_position_embeddings: int = 131072, rms_norm_eps: float = 1e-05, qkv_bias: bool = False, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -184,7 +192,7 @@ class Ernie4_5_VLMoeMoE(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -223,6 +231,21 @@ def __init__( assert text_moe_layer_start_index <= text_moe_layer_end_index + if self.has_shared_experts: + intermediate_size = ( + config.moe_intermediate_size[0] * config.moe_num_shared_experts + ) + self.shared_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts", + reduce_results=False, + ) + else: + self.shared_experts = None + if ( layer_idx >= text_moe_layer_start_index and layer_idx <= text_moe_layer_end_index @@ -236,7 +259,8 @@ def __init__( prefix=f"{prefix}.text_experts_gate", ) - self.text_experts = FusedMoE( + self.text_experts = SharedFusedMoE( + shared_experts=self.shared_experts, num_experts=config.moe_num_experts[0], top_k=config.moe_k, hidden_size=config.hidden_size, @@ -249,6 +273,7 @@ def __init__( ) else: self.text_experts = Ernie4_5_VLMoeMLP( + shared_experts=self.shared_experts, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -271,7 +296,8 @@ def __init__( prefix=f"{prefix}.vision_experts_gate", ) - self.vision_experts = FusedMoE( + self.vision_experts = SharedFusedMoE( + shared_experts=self.shared_experts, num_experts=config.moe_num_experts[1], top_k=config.moe_k, hidden_size=config.hidden_size, @@ -284,6 +310,7 @@ def __init__( ) else: self.vision_experts = Ernie4_5_VLMoeMLP( + shared_experts=self.shared_experts, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -292,19 
+319,6 @@ def __init__( prefix=f"{prefix}.mlp", ) - if self.has_shared_experts: - intermediate_size = ( - config.moe_intermediate_size[0] * config.moe_num_shared_experts - ) - self.shared_experts = Ernie4_5_VLMoeMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.shared_experts", - reduce_results=self.text_experts.must_reduce_shared_expert_outputs(), - ) - def forward( self, hidden_states: torch.Tensor, @@ -315,9 +329,6 @@ def forward( hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - if self.has_shared_experts: - shared_output = self.shared_experts(hidden_states) - if visual_token_mask is not None and visual_token_mask.all(): # only vision modal input router_logits, _ = self.vision_experts_gate( @@ -330,7 +341,10 @@ def forward( # text and vision modals input visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool() text_token_mask = ~visual_token_mask - final_hidden_states = torch.zeros_like(hidden_states) + final_experts_hidden_states = torch.zeros_like(hidden_states) + final_shared_ouput = ( + torch.zeros_like(hidden_states) if self.has_shared_experts else None + ) text_hidden_states = hidden_states[text_token_mask].reshape( -1, self.hidden_size @@ -342,16 +356,26 @@ def forward( text_router_logits, _ = self.text_experts_gate( text_hidden_states.to(dtype=torch.float32) ) - final_hidden_states[text_token_mask] = self.text_experts( + text_shared_ouput, text_experts_output = self.text_experts( hidden_states=text_hidden_states, router_logits=text_router_logits - ).flatten() + ) + final_experts_hidden_states[text_token_mask] = text_experts_output.flatten() + if self.has_shared_experts: + final_shared_ouput[text_token_mask] = text_shared_ouput.flatten() vision_router_logits, _ = self.vision_experts_gate( vision_hidden_states.to(dtype=torch.float32) ) - final_hidden_states[visual_token_mask] = self.vision_experts( + vision_shared_ouput, vision_experts_output = self.vision_experts( hidden_states=vision_hidden_states, router_logits=vision_router_logits - ).flatten() + ) + final_experts_hidden_states[visual_token_mask] = ( + vision_experts_output.flatten() + ) + if self.has_shared_experts: + final_shared_ouput[visual_token_mask] = vision_shared_ouput.flatten() + + final_hidden_states = (final_shared_ouput, final_experts_hidden_states) else: # only text modal input text_router_logits, _ = self.text_experts_gate( @@ -362,8 +386,12 @@ def forward( hidden_states=hidden_states, router_logits=text_router_logits ) - if self.has_shared_experts and shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + if self.has_shared_experts: + # for shared_experts model + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + else: + # for not shared_experts model + final_hidden_states = final_hidden_states[1] if self.tp_size > 1: final_hidden_states = ( @@ -379,8 +407,8 @@ class Ernie4_5_VLMoeDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -452,8 +480,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - visual_token_mask: Optional[torch.Tensor], + residual: torch.Tensor | None, + 
visual_token_mask: torch.Tensor | None, **kwargs: object, ) -> torch.Tensor: # Self Attention @@ -540,11 +568,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - visual_token_mask: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + visual_token_mask: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -621,10 +649,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs ) @@ -633,7 +661,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -649,7 +677,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 46a7131f2499..e7036840388c 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -24,7 +24,6 @@ """Inference-only Ernie-MTP model.""" from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -121,7 +120,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: if inputs_embeds is None: @@ -169,8 +168,8 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: assert spec_step_idx == 0, "ernie_mtp only support predict one token" @@ -183,7 +182,7 @@ def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 1f0b5723721c..84fb52d13854 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -27,7 +27,7 @@ from collections.abc 
import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -75,7 +75,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, prefix: str = "", ) -> None: @@ -115,11 +115,11 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -209,11 +209,11 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -246,8 +246,8 @@ class ExaoneDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -296,7 +296,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -369,11 +369,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -536,9 +536,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -547,7 +547,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 230a2c80104b..d5e4d9a1486f 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -23,7 +23,7 @@ from collections.abc import Iterable from itertools import 
islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -72,7 +72,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, prefix: str = "", ) -> None: @@ -112,11 +112,11 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 1000000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -222,8 +222,8 @@ class Exaone4DecoderLayer(nn.Module): def __init__( self, config: Exaone4Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -277,7 +277,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: residual = hidden_states @@ -356,11 +356,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -523,9 +523,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -534,7 +534,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 211a9120789e..25429836b9ed 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -23,7 +23,7 @@ import math from collections.abc import Iterable from itertools import islice -from typing import Optional, Union +from typing import TypeAlias import torch from torch import nn @@ -65,7 +65,7 @@ maybe_prefix, ) -FalconConfig = Union[HF_FalconConfig, RWConfig] +FalconConfig: TypeAlias = HF_FalconConfig | RWConfig def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -95,8 +95,8 @@ class FalconAttention(nn.Module): def __init__( self, config: FalconConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + 
quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -226,7 +226,7 @@ class FalconMLP(nn.Module): def __init__( self, config: FalconConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() hidden_size = config.hidden_size @@ -265,8 +265,8 @@ class FalconDecoderLayer(nn.Module): def __init__( self, config: FalconConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -401,9 +401,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -517,8 +517,8 @@ def forward( self, input_ids: torch.LongTensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds @@ -528,7 +528,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 8af08711038d..4e0b6b52fc64 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -3,7 +3,7 @@ """Inference-only FalconH1 model.""" from collections.abc import Iterable -from typing import Optional +from itertools import islice import torch from torch import nn @@ -51,7 +51,7 @@ class FalconH1MLP(nn.Module): def __init__( self, config: FalconH1Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, ) -> None: super().__init__() @@ -90,9 +90,9 @@ class FalconH1SSMDecoderLayer(nn.Module): def __init__( self, config: FalconH1Config, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -187,7 +187,7 @@ def _init_mup_vector(self): def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): output = torch.empty_like(hidden_states) @@ -203,8 +203,8 @@ class FalconH1AttentionDecoderLayer(nn.Module): def __init__( self, config: FalconH1Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -300,7 +300,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: 
Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): hidden_states = self.self_attention( @@ -325,9 +325,9 @@ def __init__( self, config: FalconH1Config, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -466,8 +466,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -480,8 +480,7 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer( positions=positions, hidden_states=hidden_states, @@ -610,8 +609,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): hidden_states = self.model( @@ -626,7 +625,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/flex_olmo.py b/vllm/model_executor/models/flex_olmo.py new file mode 100644 index 000000000000..11d0949a798a --- /dev/null +++ b/vllm/model_executor/models/flex_olmo.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
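The falcon_h1.py hunk above replaces the index-based pipeline-parallel layer loop with itertools.islice over the same range. A quick equivalence check on a toy list:

from itertools import islice

layers = [f"layer{i}" for i in range(6)]
start_layer, end_layer = 2, 5

# islice visits exactly the layers the old `for i in range(start, end)` loop did.
assert list(islice(layers, start_layer, end_layer)) == [
    layers[i] for i in range(start_layer, end_layer)
]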
+"""Inference-only FlexOlmo model compatible with HuggingFace weights.""" + +import torch +from torch import nn + +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.olmoe import OlmoeAttention, OlmoeForCausalLM +from vllm.transformers_utils.configs import FlexOlmoConfig + +logger = init_logger(__name__) + + +class FlexOlmoAttention(OlmoeAttention): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + hf_config = vllm_config.model_config.hf_config + assert isinstance(hf_config, FlexOlmoConfig) + + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, eps=hf_config.rms_norm_eps + ) + self.q_norm = RMSNorm( + self.total_num_heads * self.head_dim, eps=hf_config.rms_norm_eps + ) + + +class FlexOlmoMoE(nn.Module): + """A tensor-parallel MoE implementation for FlexOlmo that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + hf_config = vllm_config.model_config.hf_config + assert isinstance(hf_config, FlexOlmoConfig) + + tp_size = get_tensor_model_parallel_world_size() + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hf_config.hidden_size, + hf_config.num_experts, + bias=False, + return_bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + # Gate always runs at half / full precision for now. + self.experts = FusedMoE( + num_experts=hf_config.num_experts, + top_k=hf_config.num_experts_per_tok, + hidden_size=hf_config.hidden_size, + intermediate_size=hf_config.intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=None, + tp_size=tp_size, + prefix=f"{prefix}.experts", + ) + + self.top_k = hf_config.num_experts_per_tok + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + # Warning: The experts mutate the hidden state input! This messes up + # basic things like the residual stream. 
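The detach().clone() in the call just below is the guard that the warning comment refers to: if the fused experts kernel writes into its input, the activations still needed for the residual add would be corrupted. A toy sketch of that failure mode, using a hypothetical in-place stand-in for the kernel:

import torch

def inplace_expert(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in that mutates its input, as the warning describes.
    return x.mul_(2.0)

hidden = torch.ones(4)
residual = hidden                              # residual stream aliases hidden
out = inplace_expert(hidden.detach().clone())  # the copy absorbs the mutation
assert torch.equal(residual, torch.ones(4))    # hidden is left intact
assert torch.equal(out, torch.full((4,), 2.0))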
+ final_hidden_states = self.experts( + hidden_states=hidden_states.detach().clone(), + router_logits=router_logits.float(), + ) + + return final_hidden_states.view(orig_shape) + + +class FlexOlmoDecoderLayer(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + hf_config = vllm_config.model_config.hf_config + assert isinstance(hf_config, FlexOlmoConfig) + + self.self_attn = FlexOlmoAttention( + vllm_config=vllm_config, prefix=f"{prefix}.self_attn" + ) + self.post_attention_layernorm = RMSNorm( + hf_config.hidden_size, eps=hf_config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + hf_config.hidden_size, eps=hf_config.rms_norm_eps + ) + + self.mlp = FlexOlmoMoE(vllm_config=vllm_config, prefix=f"{prefix}.mlp") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Attention block. + residual = hidden_states + hidden_states = self.self_attn(positions, hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states, None + + +class FlexOlmoForCausalLM(OlmoeForCausalLM): + fall_back_to_pt_during_load = False + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = FlexOlmoDecoderLayer, + ): + super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 83572563c15e..005fac4b1f05 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -20,7 +20,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional +from typing import Annotated, Literal import torch import torch.nn as nn @@ -87,7 +87,7 @@ def get_hf_processor(self, **kwargs: object): def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor: return self.get_hf_processor(**kwargs).image_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": 1} def get_image_feature_grid_size( @@ -142,7 +142,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) @@ -271,7 +271,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None @@ -305,7 +305,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[FuyuImagePatchInputs]: + ) -> FuyuImagePatchInputs | None: image_patches = kwargs.pop("image_patches", None) patches_per_image = kwargs.pop("patches_per_image", None) @@ -344,8 +344,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - 
intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): if intermediate_tensors is not None: @@ -362,7 +362,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.language_model.logits_processor( self.language_model.lm_head, hidden_states ) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index b152f52223cf..46b111f4d939 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -20,7 +20,6 @@ from collections.abc import Iterable from functools import cache from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -59,8 +58,8 @@ @cache def _get_gemma_act_fn( - hidden_act: Optional[str], - hidden_activation: Optional[str], + hidden_act: str | None, + hidden_activation: str | None, ) -> nn.Module: if hidden_activation is None: if hidden_act is not None: @@ -92,9 +91,9 @@ def __init__( self, hidden_size: int, intermediate_size: int, - hidden_act: Optional[str] = None, - hidden_activation: Optional[str] = None, - quant_config: Optional[QuantizationConfig] = None, + hidden_act: str | None = None, + hidden_activation: str | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -130,8 +129,8 @@ def __init__( head_dim: int, max_position_embeddings: int = 8192, rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -207,8 +206,8 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, config: GemmaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -241,7 +240,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -301,9 +300,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -406,9 +405,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -417,7 +416,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> 
torch.Tensor | None: logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 2d26edcf6609..66c9b774f174 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -18,7 +18,6 @@ # limitations under the License. from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -66,7 +65,7 @@ def __init__( intermediate_size: int, hidden_act: str, hidden_activation: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -100,9 +99,9 @@ def __init__( head_dim: int, max_position_embeddings: int, rope_theta: float, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - attn_logits_soft_cap: Optional[float] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + attn_logits_soft_cap: float | None = None, prefix: str = "", ) -> None: super().__init__() @@ -183,8 +182,8 @@ class Gemma2DecoderLayer(nn.Module): def __init__( self, config: Gemma2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -225,7 +224,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states @@ -284,11 +283,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -406,9 +405,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -417,7 +416,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 9fa8e1c78b12..80ec40f478c6 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -17,7 +17,6 @@ # limitations under the License. 
from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch import torch.nn.functional as F @@ -66,7 +65,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_activation: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -108,9 +107,9 @@ def __init__( num_kv_heads: int, head_dim: int, max_position_embeddings: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - attn_logits_soft_cap: Optional[float] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + attn_logits_soft_cap: float | None = None, prefix: str = "", ) -> None: super().__init__() @@ -295,8 +294,8 @@ class Gemma3DecoderLayer(nn.Module): def __init__( self, config: Gemma3TextConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -336,7 +335,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: @@ -372,6 +371,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( @@ -400,12 +400,12 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -442,6 +442,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: + # Revert +1 during llama.cpp conversion + # see: https://github.com/ggml-org/llama.cpp/blob/be7c3034108473beda214fd1d7c98fd6a7a3bdf5/convert_hf_to_gguf.py#L3397-L3400 + if ( + self.quant_config + and self.quant_config.get_name() == "gguf" + and name.endswith("norm.weight") + ): + loaded_weight -= 1 + if self.quant_config is not None and ( scale_name := self.quant_config.get_cache_scale(name) ): @@ -539,10 +548,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs ) @@ -551,7 +560,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> 
torch.Tensor | None: logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 95b0b0dab5a1..7c628fe93ce3 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Any, Literal import torch from torch import nn @@ -82,7 +82,7 @@ def get_hf_config(self): def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(Gemma3Processor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def _resolve_image_kwargs( @@ -112,7 +112,7 @@ def get_num_crops( *, image_width: int, image_height: int, - processor: Optional[Gemma3Processor], + processor: Gemma3Processor | None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -182,7 +182,7 @@ def get_image_repl( *, image_width: int, image_height: int, - processor: Optional[Gemma3Processor], + processor: Gemma3Processor | None, ) -> PromptUpdateDetails[str]: if processor is None: processor = self.get_hf_processor() @@ -217,7 +217,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[Gemma3Processor], + processor: Gemma3Processor | None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -256,7 +256,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -510,7 +510,7 @@ class Gemma3ForConditionalGeneration( ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<start_of_image>" @@ -555,7 +555,7 @@ def dtype(self): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Gemma3ImageInputs]: + ) -> Gemma3ImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -609,8 +609,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> IntermediateTensors: if intermediate_tensors is not None: @@ -692,7 +692,7 @@ def prepare_attn_masks( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index e4ea4256ebc2..f7a732e3a601 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -16,7 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
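The gemma3.py load_weights change above subtracts 1 from GGUF RMSNorm weights: the linked llama.cpp conversion script stores Gemma norm tensors with the (1 + w) offset already applied, while the runtime norm applies the offset again, so the loader reverts it once. A toy illustration with assumed values:

import torch

gguf_norm_weight = torch.tensor([1.02, 0.98, 1.00])  # assumed serialized values
vllm_norm_weight = gguf_norm_weight - 1               # what load_weights stores
assert torch.allclose(1 + vllm_norm_weight, gguf_norm_weight)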
from collections.abc import Iterable -from typing import Optional, Union import torch from torch import nn @@ -196,7 +195,7 @@ def __init__( laurel_rank: int, rms_norm_eps: float, *, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str, ) -> None: super().__init__() @@ -236,7 +235,7 @@ def __init__( intermediate_size: int, hidden_activation: str, activation_sparsity: float = 0.0, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -285,8 +284,8 @@ def __init__( num_kv_heads: int, head_dim: int, max_position_embeddings: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -412,8 +411,8 @@ class Gemma3nDecoderLayer(nn.Module): def __init__( self, config: Gemma3nTextConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -650,7 +649,7 @@ def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tenso def get_per_layer_inputs( self, hidden_states_0: torch.Tensor, - per_layer_inputs: Optional[torch.Tensor], + per_layer_inputs: torch.Tensor | None, ) -> torch.Tensor: per_layer_projection = self.per_layer_model_projection(hidden_states_0) per_layer_projection = per_layer_projection.reshape( @@ -687,8 +686,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - per_layer_inputs: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if inputs_embeds is not None: @@ -870,8 +869,8 @@ def fast_prefill_forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - per_layer_inputs: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: logits_indices_padded, num_logits_indices = None, None @@ -947,8 +946,8 @@ def normal_forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - per_layer_inputs: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: hidden_states, per_layer_inputs = self.self_decoder( @@ -990,13 +989,13 @@ def altup_unembed( def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - per_layer_inputs: Optional[torch.Tensor] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if self.fast_prefill_enabled: hidden_states = self.fast_prefill_forward( input_ids, @@ -1116,11 +1115,11 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, *, - per_layer_inputs: 
Optional[torch.Tensor] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, @@ -1134,7 +1133,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 0e69fcfd8feb..2b727a538bf2 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -58,7 +58,6 @@ from .utils import ( AutoWeightsLoader, WeightsMapper, - flatten_bn, init_vllm_registered_model, maybe_prefix, ) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index f25f50602e6c..d7fd2b109d24 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -24,7 +24,6 @@ """Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Optional, Union import torch from torch import nn @@ -56,12 +55,12 @@ def __init__( num_heads: int, num_kv_heads: int, max_position: int = 4096 * 32, - head_dim: Optional[int] = None, + head_dim: int | None = None, qkv_bias: bool = False, rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + rope_scaling: tuple | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -142,7 +141,7 @@ def __init__( self, vllm_config: VllmConfig, prefix: str = "", - config: Optional[Glm4Config] = None, + config: Glm4Config | None = None, ) -> None: super().__init__() @@ -189,7 +188,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -285,9 +284,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -296,7 +295,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 304e721fade5..9f1439e21ef7 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -27,18 +27,16 @@ """Inference-only GLM-4V model compatible with HuggingFace weights.""" import math -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, 
Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from packaging.version import Version from transformers import BatchFeature -from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.image_processing_glm4v import ( Glm4vImageProcessor, @@ -62,6 +60,7 @@ ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -100,7 +99,11 @@ init_vllm_registered_model, maybe_prefix, ) -from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from .vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -140,7 +143,7 @@ class Glm4vImageEmbeddingInputs(TensorSchema): image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)] -Glm4vImageInputs = Union[Glm4vImagePixelInputs, Glm4vImageEmbeddingInputs] +Glm4vImageInputs: TypeAlias = Glm4vImagePixelInputs | Glm4vImageEmbeddingInputs class Glm4vVideoPixelInputs(TensorSchema): @@ -176,7 +179,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema): video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)] -Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs] +Glm4vVideoInputs: TypeAlias = Glm4vVideoPixelInputs | Glm4vVideoEmbeddingInputs # ==== Vision Encoder ==== # @@ -187,7 +190,7 @@ def __init__( in_features: int, hidden_features: int, bias: bool = False, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -244,9 +247,10 @@ def __init__( embed_dim: int, num_heads: int, projection_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. 
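Note: the recurring change across these model files is the PEP 604 annotation rewrite (Optional[X] -> X | None, Union[X, Y] -> X | Y), which also lets the typing imports be dropped. A minimal sketch of the pattern on a toy function (hypothetical, not taken from the patch; evaluating "|" in signatures at runtime assumes Python 3.10+):

import torch
from typing import Optional, Union  # only needed for the old spelling


def old_style(x: torch.Tensor, residual: Optional[torch.Tensor] = None) -> Union[torch.Tensor, None]:
    return x if residual is None else x + residual


def new_style(x: torch.Tensor, residual: torch.Tensor | None = None) -> torch.Tensor | None:
    return x if residual is None else x + residual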
@@ -287,6 +291,7 @@ def __init__( self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) self.use_upstream_fa = False @@ -294,6 +299,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) @@ -334,8 +340,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -413,10 +419,11 @@ def __init__( dim: int, num_heads: int, mlp_hidden_dim: int, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -430,6 +437,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) self.mlp = Glm4vVisionMLP( dim, @@ -445,8 +453,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), @@ -475,18 +483,15 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( - in_channels, + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), hidden_size, - kernel_size=kernel_size, - stride=kernel_size, bias=True, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x).view(L, self.hidden_size) + x = self.proj(x) return x @@ -495,7 +500,7 @@ def __init__( self, d_model: int, context_dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, prefix: str = "", use_data_parallel: bool = False, @@ -693,9 +698,10 @@ def __init__( self, vision_config: Glm4vVisionConfig, norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() @@ -731,6 +737,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}", use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] @@ -759,7 +766,9 @@ def __init__( ) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype() + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) if self.attn_backend != 
_Backend.FLASH_ATTN and check_upstream_fa_availability( torch.get_default_dtype() @@ -809,7 +818,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[Optional[int], Optional[list[int]]]: + ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() if ( @@ -880,6 +889,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -904,7 +916,7 @@ def get_hf_config(self): def get_tokenizer(self): return self.ctx.tokenizer - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": 1} def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor: @@ -1141,7 +1153,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1177,7 +1189,7 @@ def _get_dummy_videos( height: int, num_frames: int, num_videos: int, - overrides: Optional[VideoDummyOptions] = None, + overrides: VideoDummyOptions | None = None, ) -> list[VideoItem]: if overrides: if overrides.num_frames: @@ -1261,14 +1273,7 @@ def _call_hf_processor( video_mm_data = dict() video_mm_data["videos"] = [[video_array]] - # backward compatibility for Transformers 4.55 unuse_metadata = ["do_sample_frames"] - if ( - not hasattr(VideoMetadata, "frames_indices") - and "frames_indices" in metadata - ): - unuse_metadata.append("frames_indices") - video_mm_data["video_metadata"] = [ [ VideoMetadata( @@ -1287,24 +1292,11 @@ def _call_hf_processor( mm_kwargs=video_mm_kwargs, tok_kwargs=tok_kwargs, ) - if not video_mm_kwargs["do_sample_frames"] and Version( - TRANSFORMERS_VERSION - ) < Version("4.56.0"): - # Transformers v4.55 has incorrect timestamps issue for - # skip sampling. We construct the placeholder manually to - # get placeholders with correct timestamps. 
- placeholder = self.info._construct_video_placeholder( - video_array, - metadata, - video_outputs["video_grid_thw"].squeeze(0), - ) - video_placeholder = processor.tokenizer.decode(placeholder) - else: - input_ids = video_outputs.pop("input_ids") - input_ids[input_ids == processor.image_token_id] = ( - processor.video_token_id - ) - video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] + input_ids = video_outputs.pop("input_ids") + input_ids[input_ids == processor.image_token_id] = ( + processor.video_token_id + ) + video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace( "<|begin_of_video|><|video|><|end_of_video|>", video_placeholder, @@ -1419,7 +1411,7 @@ class Glm4vForConditionalGeneration( supports_encoder_tp_data = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|begin_of_image|><|image|><|end_of_image|>" if modality.startswith("video"): @@ -1437,12 +1429,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.visual = Glm4vVisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-5), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, ) if config.model_type == "glm4v": @@ -1465,7 +1463,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Glm4vImageInputs]: + ) -> Glm4vImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1489,7 +1487,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[Glm4vVideoInputs]: + ) -> Glm4vVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1594,7 +1592,7 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings( self, **kwargs: object - ) -> Optional[MultiModalEmbeddings]: + ) -> MultiModalEmbeddings | None: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None @@ -1608,21 +1606,21 @@ def get_multimodal_embeddings( for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - vision_embeddings = self._process_image_input(multimodal_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, 
+ inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for GLM-4V. Args: @@ -1652,7 +1650,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 5db6f297dbf2..a53f52852c6a 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -26,7 +26,7 @@ import typing from collections.abc import Callable, Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -42,7 +42,7 @@ ) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -52,7 +52,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -82,7 +81,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ) -> None: @@ -119,7 +118,7 @@ class Glm4MoE(nn.Module): def __init__( self, config: Glm4MoeConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", enable_eplb: bool = False, ): @@ -176,46 +175,29 @@ def __init__( reduce_results=False, prefix=f"{prefix}.shared_experts", ) - self.experts = SharedFusedMoE( - shared_experts=self.shared_experts, - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func="sigmoid", - # we do scaling outside, set factor to 1.0 to avoid double mul - routed_scaling_factor=1.0, - e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - ) else: - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func="sigmoid", - # we do scaling outside, set factor to 1.0 to avoid double mul - routed_scaling_factor=1.0, - 
e_score_correction_bias=self.gate.e_score_correction_bias, - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - ) + self.shared_experts = None + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + # we do scaling outside, set factor to 1.0 to avoid double mul + routed_scaling_factor=1.0, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -252,14 +234,14 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 131072, - head_dim: Optional[int] = None, + head_dim: int | None = None, rms_norm_eps: float = 1e-05, qkv_bias: bool = False, use_qk_norm: bool = False, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -352,8 +334,8 @@ class Glm4MoeDecoderLayer(nn.Module): def __init__( self, config: Glm4MoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", enable_eplb: bool = False, ) -> None: @@ -413,7 +395,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states @@ -480,9 +462,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -522,7 +504,7 @@ def make_empty_intermediate_tensors( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -677,7 +659,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups = config.n_group - self.moe_layers: list[FusedMoE] = [] + self.moe_layers: list[SharedFusedMoE] = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -722,9 +704,9 @@ def forward( self, input_ids: 
torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -733,7 +715,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -747,7 +729,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_spec_layer_idx_from_weight_name( config: Glm4MoeConfig, weight_name: str -) -> Optional[int]: +) -> int | None: if hasattr(config, "num_nextn_predict_layers") and ( config.num_nextn_predict_layers > 0 ): diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index beb40632246c..9fb1be7ba45c 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -24,7 +24,6 @@ """Inference-only GLM-4.5 MTP model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -52,7 +51,7 @@ def __init__( self, config: PretrainedConfig, prefix: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -72,8 +71,8 @@ def __init__( self, config: PretrainedConfig, prefix: str, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -94,7 +93,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None @@ -149,7 +148,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: if inputs_embeds is None: @@ -192,8 +191,8 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: hidden_states = self.model( @@ -205,7 +204,7 @@ def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.model.compute_logits(hidden_states, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index a5c3ce0e6bf7..2de1e4810952 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -5,16 +5,17 @@ # 
https://github.com/zai-org/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" +import itertools from argparse import Namespace from collections.abc import Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch from torch import nn from torch.nn import LayerNorm from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import BatchFeature, PreTrainedTokenizer, TensorType +from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput @@ -54,6 +55,7 @@ from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -107,7 +109,7 @@ class EVA2CLIPAttention(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -150,7 +152,7 @@ class EVA2CLIPMLP(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -180,7 +182,7 @@ class EVA2CLIPTransformerLayer(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -209,7 +211,7 @@ class EVA2CLIPTransformer(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -235,7 +237,7 @@ def __init__( self, config, in_features, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): """ @@ -315,7 +317,7 @@ class EVA2CLIPModel(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -414,9 +416,9 @@ def __init__( def __call__( self, - text: Optional[Union[TextInput, list[TextInput]]] = None, - images: Optional[Union[ImageInput, list[ImageInput]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: TextInput | list[TextInput] | None = None, + images: ImageInput | list[ImageInput] | None = None, + return_tensors: str | TensorType | None = None, ) -> BatchFeature: if text is None: text = [] @@ -456,7 +458,7 @@ def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor: **kwargs, ) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": 1} def get_num_image_tokens(self) -> int: @@ -485,7 +487,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config @@ -554,7 +556,9 @@ def get_replacement(item_idx: int): info=GLM4VProcessingInfo, dummy_inputs=GLM4VDummyInputsBuilder, ) -class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP): +class GLM4VForCausalLM( + ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): 
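# Sketch of what get_mrope_input_positions below produces (illustration
# derived from the code, not part of the patch): text tokens advance the
# three (t, h, w) rows together, while image tokens enumerate their merged
# grid. For two text tokens followed by one image with grid_thw = (1, 4, 4)
# and spatial_merge_size = 2 (a 1 x 2 x 2 = 4-token placeholder), the 3 x N
# positions are
#   t: [0, 1, 2, 2, 2, 2, 4, ...]
#   h: [0, 1, 2, 2, 3, 3, 4, ...]
#   w: [0, 1, 2, 3, 2, 3, 4, ...]
# i.e. the image block is offset by st_idx = 2 and the following text
# resumes at max + 1 = 4.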
merge_by_field_config = True packed_modules_mapping = { @@ -574,7 +578,7 @@ def get_mm_mapping(self) -> MultiModelKeys: ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|begin_of_image|><|endoftext|><|end_of_image|>" @@ -597,7 +601,7 @@ def __init__( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[GLMVImagePixelInputs]: + ) -> GLMVImagePixelInputs | None: pixel_values = kwargs.pop("pixel_values", None) if pixel_values is not None: @@ -611,10 +615,153 @@ def _parse_and_validate_image_input( return None def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor: - pixel_values = image_input["data"].to(dtype=self.config.torch_dtype) + pixel_values = image_input["data"].to(dtype=self.config.dtype) return self.transformer.vision(pixel_values) + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + context_len: int = 0, + seq_len: int | None = None, + second_per_grid_ts: list[float] | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for GLM4V.""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1] + ): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + 
video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + for t_idx in range(llm_grid_t): + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return llm_positions, mrope_position_delta + def get_language_model(self) -> torch.nn.Module: return self.transformer @@ -632,10 +779,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 53d6026c5938..6d99d02a32be 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -22,7 +22,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -51,7 +50,7 @@ from vllm.sequence import IntermediateTensors from ..layers.pooler import DispatchPooler, Pooler -from .interfaces import SupportsPP +from .interfaces import SupportsCrossEncoding, SupportsPP from .utils import ( AutoWeightsLoader, is_pp_missing_parameter, @@ -65,8 +64,8 @@ class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -118,7 +117,7 @@ def __init__( self, intermediate_size: int, config: GPT2Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -150,8 +149,8 @@ class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -221,9 +220,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor], - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: 
torch.Tensor | None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings(input_ids) @@ -301,9 +300,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -312,7 +311,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -322,7 +321,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loader.load_weights(weights) -class GPT2ForSequenceClassification(nn.Module): +class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding): """GPT2 Model for sequence classification. This class expands GPT2Model with pooling and score functions - last token @@ -354,11 +353,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), - "classify": Pooler.for_classify(pooler_config, classifier=self.score), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.score + ), + "classify": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="classify" + ), + "score": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="score" + ), } ) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights) @@ -367,8 +376,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: hidden_states = self.transformer( input_ids=input_ids, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b6d3d8f3f2e6..f2c8e2aeb822 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -23,7 +23,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -62,8 +61,8 @@ class GPTBigCodeAttention(nn.Module): def __init__( self, config: GPTBigCodeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -133,7 +132,7 @@ def __init__( self, intermediate_size: int, config: GPTBigCodeConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -165,8 +164,8 @@ class GPTBigCodeBlock(nn.Module): def __init__( self, config: GPTBigCodeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: 
Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -243,9 +242,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings(input_ids) @@ -326,9 +325,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -337,7 +336,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5428512dec19..1777fd3583c3 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -21,7 +21,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -64,8 +63,8 @@ class GPTJAttention(nn.Module): def __init__( self, config: GPTJConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -130,7 +129,7 @@ def __init__( self, intermediate_size: int, config: GPTJConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() hidden_size = config.n_embd @@ -157,8 +156,8 @@ class GPTJBlock(nn.Module): def __init__( self, config: GPTJConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -218,9 +217,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -322,9 +321,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.transformer( input_ids, positions, 
intermediate_tensors, inputs_embeds ) @@ -333,7 +332,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias) return logits diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 8278ae03d88a..2f638acaa2b6 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -21,7 +21,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -61,8 +60,8 @@ class GPTNeoXAttention(nn.Module): def __init__( self, config: GPTNeoXConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -125,7 +124,7 @@ class GPTNeoXMLP(nn.Module): def __init__( self, config: GPTNeoXConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.dense_h_to_4h = ColumnParallelLinear( @@ -151,8 +150,8 @@ class GPTNeoXLayer(nn.Module): def __init__( self, config: GPTNeoXConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -232,9 +231,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -320,9 +319,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.gpt_neox( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -331,7 +330,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.embed_out, hidden_states) return logits diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 17f911435079..846c8e7669be 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.distributed as dist @@ -12,6 +11,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import ( + get_dp_group, get_ep_group, get_pp_group, get_tensor_model_parallel_rank, @@ -19,6 +19,7 @@ tensor_model_parallel_all_gather, ) from vllm.model_executor.layers.fused_moe import 
FusedMoE +from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -33,7 +34,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .interfaces import SupportsEagle3, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -49,8 +50,8 @@ class OAIAttention(nn.Module): def __init__( self, config: GptOssConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, + quant_config: QuantizationConfig | None = None, + cache_config: CacheConfig | None = None, prefix: str = "", ): super().__init__() @@ -208,7 +209,7 @@ def forward( self, hidden_states: torch.Tensor, positions: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual is None: @@ -217,6 +218,7 @@ def forward( else: hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.attn(hidden_states, positions) + # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) output = self.mlp(hidden_states) @@ -260,8 +262,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -305,8 +307,13 @@ def _load_weights_mxfp4( use_ep = self.parallel_config.enable_expert_parallel num_experts = self.config.num_local_experts - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() + # In MoE, we need to flatten the tensor parallel size across the data + # parallel size when EP is disabled. + tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp( + tp_size=get_tensor_model_parallel_world_size(), + dp_size=get_dp_group().world_size, + dp_rank=get_dp_group().rank_in_group, + ) intermediate_size = self.config.intermediate_size intermediate_size_block = intermediate_size // mxfp4_block @@ -488,8 +495,13 @@ def _load_weights_other( use_ep = self.parallel_config.enable_expert_parallel - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() + # In MoE, we need to flatten the tensor parallel size across the data + # parallel size when EP is disabled. 
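# Illustration of the flattening (assumed behavior of the helper, not its
# verbatim implementation): with EP disabled the DP replicas act as extra
# TP shards over the expert weights, roughly
#   flat_tp_size = tp_size * dp_size
#   flat_tp_rank = dp_rank * tp_size + tp_rank
# so each (dp_rank, tp_rank) pair loads a distinct slice of intermediate_size.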
+ tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp( + tp_size=get_tensor_model_parallel_world_size(), + dp_size=get_dp_group().world_size, + dp_rank=get_dp_group().rank_in_group, + ) intermediate_size = self.config.intermediate_size per_rank_intermediate_size = cdiv(intermediate_size, tp_size) @@ -627,7 +639,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) -class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): +class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper( @@ -687,8 +699,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: return self.model(input_ids, positions, intermediate_tensors, inputs_embeds) @@ -696,6 +708,17 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, weight scales, activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_local_experts, + num_redundant_experts=0, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index e9bc592c0797..5fc8718ca75e 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -26,7 +26,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -73,7 +73,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, prefix: str = "", ) -> None: @@ -113,11 +113,11 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -197,8 +197,8 @@ class GraniteDecoderLayer(nn.Module): def __init__( self, config: GraniteConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -323,11 +323,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, 
IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -480,15 +480,15 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) return model_output - def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index 82bceaf3ed01..043b1406bd37 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -26,7 +26,7 @@ import math from collections.abc import Iterable, Mapping -from typing import Annotated, Optional, Union +from typing import Annotated import torch import torch.nn.functional as F @@ -92,7 +92,7 @@ class GraniteSpeechAudioInputs(TensorSchema): class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": 1} # There is no limit to the maximum number of audio tokens that can be @@ -196,7 +196,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) audio_overrides = mm_options.get("audio") if mm_options else None @@ -222,7 +222,7 @@ def __init__( self, config: PretrainedConfig, cache_config: CacheConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -279,7 +279,7 @@ class GraniteSpeechConformerFeedForward(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -479,7 +479,7 @@ def __init__( self, config: PretrainedConfig, prefix: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.config = config @@ -561,7 +561,7 @@ class GraniteSpeechForConditionalGeneration( } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("audio"): return "<|audio|>" @@ -606,7 +606,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_audio_input( self, **kwargs: object, - ) -> Optional[GraniteSpeechAudioInputs]: + ) -> GraniteSpeechAudioInputs | None: input_features = kwargs.pop("input_features", None) input_features_mask = kwargs.pop("input_features_mask", None) audio_embed_sizes = kwargs.pop("audio_embed_sizes", None) @@ -763,9 +763,9 
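# Illustrative note, not part of the diff: the recurring annotation edits in
# these files replace typing.Optional / typing.Union with PEP 604 unions,
# which compare equal at runtime on Python 3.10+:
from typing import Optional, Union

assert (int | None) == Optional[int]
assert (int | str) == Union[int, str]
# e.g. `-> Optional[torch.Tensor]` becomes `-> torch.Tensor | None` with no
# behavioural change, and the typing imports can be dropped.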
@@ def get_multimodal_embeddings( def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, # Multi-modal token ID may exceed vocab size handle_oov_mm_token: bool = True, ) -> torch.Tensor: @@ -784,10 +784,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -799,7 +799,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights( diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 4711ed05c587..e683f30805f3 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -26,7 +26,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional +from typing import Any import torch from torch import nn @@ -79,9 +79,9 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + tp_size: int | None = None, is_sequence_parallel=False, prefix: str = "", ): @@ -143,10 +143,10 @@ def __init__( num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - attention_multiplier: Optional[float] = None, + rope_scaling: dict[str, Any] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + attention_multiplier: float | None = None, prefix: str = "", ) -> None: super().__init__() @@ -330,8 +330,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -557,15 +557,15 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/granitemoehybrid.py 
b/vllm/model_executor/models/granitemoehybrid.py index f877dc576427..1bb7f4e9b802 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -4,7 +4,6 @@ # Added by the IBM Team, 2025 from collections.abc import Iterable -from typing import Optional import torch from torch import nn @@ -50,9 +49,9 @@ def __init__( self, config: GraniteMoeHybridConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -105,7 +104,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): residual = hidden_states @@ -139,9 +138,9 @@ def __init__( self, config: GraniteMoeHybridConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -183,7 +182,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -218,9 +217,9 @@ class GraniteMoeHybridAttention(nn.Module): def __init__( self, config: GraniteMoeHybridConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -331,6 +330,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config + self.quant_config = quant_config lora_vocab = ( (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config @@ -374,8 +374,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -406,6 +406,33 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + # layers.0.block_sparse_moe.expert_0.input_linear.input_scale + ckpt_gate_proj_name = "gate_proj" + ckpt_down_proj_name = "down_proj" + ckpt_up_proj_name = "up_proj" + num_experts = self.config.num_local_experts + + return [ + # (param_name, weight_name, expert_id, shard_id) + ( + "block_sparse_moe.experts.w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else "block_sparse_moe.experts.w2_", + f"block_sparse_moe.experts.{expert_id}.{weight_name}.", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for 
shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -415,6 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() def _load(n, p): param = params_dict[n] @@ -436,10 +464,56 @@ def _load_expert(n, p, name, shard_id, expert_id): weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id) loaded_params.add(n) + def _load_quant_expert(name, loaded_weight): + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + name_mapped = name.replace(weight_name, param_name) + + # Skip layers on other devices. + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + weight_loader = param.weight_loader + success = False + + if weight_loader is not None: + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + + if success: + return name_mapped + return None + for n, p in weights: if "A_log" in n: n = n.replace("A_log", "A") + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(n) + ): + # Loading kv cache quantization scales + loaded_weight = p + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + _load(scale_name, loaded_weight) + loaded_params.add(scale_name) + continue + + if _load_quant_expert(n, p): + continue + # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215 # Mapping different experts' layout: # from HF (input_linear, output_linear, router) @@ -614,8 +688,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): hidden_states = self.model( @@ -627,7 +701,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index 93302821ca68..e222109f2a94 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -8,7 +8,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional import torch from torch import nn @@ -41,7 +40,7 @@ class GraniteMoeSharedMLP(nn.Module): def __init__( self, config: GraniteMoeSharedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -80,8 +79,8 @@ class GraniteMoeSharedDecoderLayer(nn.Module): def __init__( self, config: GraniteMoeSharedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: 
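# Illustrative sketch, not part of the diff: what the get_expert_mapping()
# comprehension above produces, reproduced standalone for two experts.
ckpt_names = [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]
mapping = [
    (
        "block_sparse_moe.experts.w13_"
        if weight_name in ("gate_proj", "up_proj")
        else "block_sparse_moe.experts.w2_",
        f"block_sparse_moe.experts.{expert_id}.{weight_name}.",
        expert_id,
        shard_id,
    )
    for expert_id in range(2)
    for shard_id, weight_name in ckpt_names
]
# gate/up projections map onto the fused "w13_" parameter, down_proj onto
# "w2_"; _load_quant_expert() uses these entries to rewrite checkpoint names
# and pass the expert/shard ids to the parameter's weight_loader.
assert mapping[0] == (
    "block_sparse_moe.experts.w13_",
    "block_sparse_moe.experts.0.gate_proj.",
    0,
    "w1",
)
assert mapping[1][0] == "block_sparse_moe.experts.w2_"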
QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -198,8 +197,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -321,15 +320,15 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index ac78dd9e753a..181c4ed2dca5 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Set -from typing import Optional, Union import numpy as np import torch @@ -62,7 +61,7 @@ def _find_array( arr: np.ndarray, target: np.ndarray, start_idx: int = 0, - end_idx: Optional[int] = None, + end_idx: int | None = None, ) -> int: """ Find the first occurrence of `target` in `arr` starting from @@ -149,27 +148,22 @@ def get_supported_tasks(self) -> Set[PoolingTask]: def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: return PoolingParamsUpdate(requires_token_ids=True) - def forward_one( + def forward( self, - hidden_states: torch.Tensor, - prompt_len: Optional[torch.Tensor] = None, - instr_len: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], ( - "partial prefill not supported with MEAN pooling" + hidden_states: torch.Tensor | list[torch.Tensor], + pooling_metadata: PoolingMetadata, + ) -> list[torch.Tensor] | torch.Tensor: + prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + instr_lens = torch.tensor( + [ + self._get_instruction_len(token_ids.cpu().numpy()) + for token_ids in get_prompt_token_ids(pooling_metadata) + ], + device="cpu", ) - return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32) - - def forward_all( - self, - hidden_states: torch.Tensor, - prompt_lens: torch.Tensor, - instr_lens: torch.Tensor, - ) -> Union[list[torch.Tensor], torch.Tensor]: offset = 0 pooled_data = list[torch.Tensor]() - for prompt_len, instr_len in zip(prompt_lens, instr_lens): pooled_data.append( hidden_states[offset + instr_len : offset + prompt_len].mean( @@ -180,30 +174,6 @@ def forward_all( return pooled_data - def forward( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) - instr_lens = torch.tensor( - [ - self._get_instruction_len(token_ids.cpu().numpy()) - for token_ids in get_prompt_token_ids(pooling_metadata) - ], - device=prompt_lens.device, - ) - - if 
isinstance(hidden_states, list): - return [ - self.forward_one(h, prompt_len, instr_len) - for h, prompt_len, instr_len in zip( - hidden_states, prompt_lens, instr_lens - ) - ] - - return self.forward_all(hidden_states, prompt_lens, instr_lens) - class GritLMPooler(Pooler): def __init__(self, model_config: ModelConfig): @@ -269,7 +239,7 @@ def __init__( if pooler_config is not None: self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_embed": Pooler.for_token_embed(pooler_config), "embed": GritLMPooler(vllm_config.model_config), } ) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index f4139685b79f..d77a0bc2993a 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -25,7 +25,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch import torch.nn.functional as F @@ -86,9 +85,9 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + tp_size: int | None = None, prefix: str = "", ): super().__init__() @@ -137,8 +136,8 @@ def __init__( num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", config=None, # Added config parameter ) -> None: @@ -223,8 +222,8 @@ class Grok1DecoderLayer(nn.Module): def __init__( self, config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -273,7 +272,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -351,9 +350,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -544,9 +543,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -555,7 +554,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 
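# Illustrative sketch, not part of the diff: the merged GritLM forward() walks
# a flat [total_tokens, hidden] tensor and mean-pools each prompt's hidden
# states while skipping its instruction tokens, as in the loop above.
import torch

def mean_pool_skipping_instruction(
    hidden_states: torch.Tensor,  # [sum(prompt_lens), hidden]
    prompt_lens: list[int],
    instr_lens: list[int],
) -> list[torch.Tensor]:
    offset = 0
    pooled = []
    for prompt_len, instr_len in zip(prompt_lens, instr_lens):
        pooled.append(
            hidden_states[offset + instr_len : offset + prompt_len].mean(
                dim=0, dtype=torch.float32
            )
        )
        offset += prompt_len
    return pooled

# Two prompts of 4 and 3 tokens, with 2 and 1 instruction tokens respectively.
h = torch.randn(7, 8)
out = mean_pool_skipping_instruction(h, [4, 3], [2, 1])
assert out[0].shape == (8,) and torch.allclose(out[0], h[2:4].mean(dim=0))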
d7ee0fd8fd37..81c6b34bd6ce 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -9,7 +9,6 @@ # Licensed under Apache 2.0 License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Mapping, Sequence -from typing import Optional, Union import torch from PIL import Image @@ -67,7 +66,7 @@ def get_h2ovl_target_ratios( min_num: int, max_num: int, *, - prior_aspect_ratio: Optional[tuple[int, int]], + prior_aspect_ratio: tuple[int, int] | None, ) -> list[tuple[int, int]]: target_ratios = get_internvl_target_ratios(min_num, max_num) @@ -170,7 +169,7 @@ def _preprocess_image( min_num: int, max_num: int, use_thumbnail: bool, - prior_aspect_ratio: Optional[tuple[int, int]], + prior_aspect_ratio: tuple[int, int] | None, ) -> tuple[torch.Tensor, tuple[int, int]]: target_ratios = get_h2ovl_target_ratios( min_num, @@ -244,10 +243,10 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - use_msac: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + use_msac: bool | None = None, ) -> None: super().__init__( config, @@ -270,7 +269,7 @@ def image_token_id(self) -> int: def get_image_repl( self, feature_size: int, - num_patches: Optional[int], + num_patches: int | None, ) -> PromptUpdateDetails[str]: repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END @@ -280,10 +279,10 @@ def get_image_repl( def resolve_min_max_num( self, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - use_thumbnail: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + use_thumbnail: bool | None = None, ) -> tuple[int, int]: min_dynamic_patch = ( self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch @@ -308,12 +307,12 @@ def resolve_min_max_num( def resolve_target_ratios( self, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - use_thumbnail: Optional[bool] = None, - prior_aspect_ratio: Optional[tuple[int, int]] = None, - override_min_num: Optional[int] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + use_thumbnail: bool | None = None, + prior_aspect_ratio: tuple[int, int] | None = None, + override_min_num: int | None = None, ) -> list[tuple[int, int]]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -335,7 +334,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - use_msac: Optional[bool] = None, + use_msac: bool | None = None, ) -> int: use_msac = self.use_msac if use_msac is None else use_msac @@ -385,9 +384,9 @@ def get_num_image_tokens( def _images_to_pixel_values_lst( self, images: list[Image.Image], - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, ) -> list[torch.Tensor]: use_msac = self.use_msac if len(images) == 1 else False @@ -425,8 
+424,8 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[H2OVLProcessor], - use_msac: Optional[bool] = None, + processor: H2OVLProcessor | None, + use_msac: bool | None = None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -493,11 +492,11 @@ def get_replacement_internvl(item_idx: int): def _cached_apply_hf_processor( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is @@ -530,7 +529,7 @@ class H2OVLChatModel(InternVLChatModel): def _init_vision_model( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, is_mono: bool, prefix: str, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index d33406b7be2b..901f29310872 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -26,7 +26,8 @@ import typing from collections.abc import Callable, Iterable -from typing import Any, Optional, Union +from itertools import islice +from typing import Any import regex as re import torch @@ -43,7 +44,7 @@ tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -101,7 +102,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, prefix: str = "", reduce_results: bool = True, @@ -143,11 +144,11 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", layer_id: int = -1, ) -> None: @@ -226,7 +227,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_states: Optional[tuple[torch.Tensor]] = None, + kv_states: tuple[torch.Tensor] | None = None, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -255,11 +256,11 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", layer_id: int = -1, ) -> None: @@ -337,7 +338,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - 
kv_states: Optional[tuple[torch.Tensor]] = None, + kv_states: tuple[torch.Tensor] | None = None, ) -> torch.Tensor: assert kv_states is not None ori_k, v = kv_states # use last layer kv, @@ -364,7 +365,7 @@ class HunYuanSparseMoeBlock(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, layer_id: int = -1, prefix: str = "", enable_eplb: bool = False, @@ -414,19 +415,6 @@ def __init__( self.physical_expert_start + self.n_local_physical_experts ) - self.experts = FusedMoE( - num_experts=self.n_routed_experts, - top_k=top_k, - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - reduce_results=False, - renormalize=top_k > 1, - quant_config=quant_config, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - ) - self.gate = ReplicatedLinear( config.hidden_size, config.num_experts, @@ -454,22 +442,34 @@ def __init__( else: self.shared_mlp = None + self.experts = SharedFusedMoE( + shared_experts=self.shared_mlp, + num_experts=self.n_routed_experts, + top_k=top_k, + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + renormalize=top_k > 1, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None - if self.shared_mlp is not None: - shared_output = self.shared_mlp(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + if self.shared_mlp is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] + if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) @@ -480,8 +480,8 @@ class HunYuanDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", layer_id: int = -1, enable_eplb: bool = False, @@ -577,8 +577,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - kv_states: Optional[tuple[torch.Tensor]] = None, + residual: torch.Tensor | None, + kv_states: tuple[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -654,11 +654,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -672,8 
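# Illustrative sketch, not part of the diff: with SharedFusedMoE the shared
# expert is handed to the fused MoE module, and the caller gets back a pair of
# partial outputs to sum (hence final_hidden_states[0] + final_hidden_states[1]
# above).  A toy stand-in for that calling convention; the real module's return
# contract may differ in detail:
import torch

def toy_shared_fused_moe(hidden_states, routed_fn, shared_fn=None):
    routed = routed_fn(hidden_states)
    if shared_fn is None:
        return routed                        # FusedMoE-style: single tensor
    return shared_fn(hidden_states), routed  # SharedFusedMoE-style: a pair

x = torch.randn(4, 16)
out = toy_shared_fused_moe(x, routed_fn=lambda h: 2 * h, shared_fn=lambda h: h + 1)
combined = out[0] + out[1] if isinstance(out, tuple) else out
assert combined.shape == x.shape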
+672,9 @@ def forward( cla_factor = _get_cla_factor(self.config) prev_kv_states = None - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for i, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): hidden_states, residual, kv_states = layer( positions, hidden_states, @@ -681,10 +682,7 @@ def forward( prev_kv_states, ) - if ( - getattr(self.config, "use_cla", False) - and (i - self.start_layer) % cla_factor == 0 - ): + if getattr(self.config, "use_cla", False) and i % cla_factor == 0: prev_kv_states = kv_states else: prev_kv_states = None @@ -725,7 +723,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: if _is_moe(self.config): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -962,9 +960,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -973,7 +971,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -1009,7 +1007,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Set MoE hyperparameters self.expert_weights = [] self.num_expert_groups = 1 - self.moe_layers: list[FusedMoE] = [] + self.moe_layers: list[SharedFusedMoE] = [] example_layer = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 611c14733c71..3d28ba951b94 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -6,7 +6,7 @@ from collections.abc import Iterable, Mapping, Sequence from functools import partial from itertools import accumulate -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal import numpy as np import torch @@ -115,13 +115,13 @@ class HCXVisionProcessingInfo(BaseProcessingInfo): def get_vision_encoder_info(self): return get_vision_encoder_info(self.get_hf_config()) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": None} def get_num_image_tokens( self, *, - vision_query_length: Union[int, list[int]], + vision_query_length: int | list[int], ) -> int: if isinstance(vision_query_length, int): return vision_query_length @@ -131,7 +131,7 @@ def get_num_image_tokens( def get_num_video_tokens( self, *, - vision_query_length: Union[int, list[int]], + vision_query_length: int | list[int], ) -> int: if isinstance(vision_query_length, int): return vision_query_length @@ -166,7 +166,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, 
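# Illustrative note, not part of the diff: enumerate(islice(layers, start, end))
# counts from 0, so the new loop index already equals the old
# (i - self.start_layer) and the cross-layer-attention check reduces to
# i % cla_factor == 0.
from itertools import islice

layers = ["L0", "L1", "L2", "L3", "L4"]
start_layer, end_layer = 2, 5
old = [(i - start_layer, layers[i]) for i in range(start_layer, end_layer)]
new = list(enumerate(islice(layers, start_layer, end_layer)))
assert old == new == [(0, "L2"), (1, "L3"), (2, "L4")]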
BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -346,7 +346,7 @@ def _build_hcxvision_hf_processor( info: HCXVisionProcessingInfo, dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> BaseMultiModalProcessor: if isinstance(info, HCXVisionProcessingInfo): return HCXVisionMultiModalProcessor( @@ -360,12 +360,12 @@ def _build_hcxvision_hf_processor( def init_vision_tower_for_hcxvision( vision_config, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, - use_nth_layer: Optional[int] = None, - require_post_norm: Optional[bool] = None, + use_nth_layer: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", -) -> Union[CLIPVisionModel, SiglipVisionModel]: +) -> CLIPVisionModel | SiglipVisionModel: num_hidden_layers = vision_config.num_hidden_layers if not isinstance(use_nth_layer, int): pass @@ -473,8 +473,8 @@ def __init__( def forward( self, x: torch.Tensor, - num_queries_vis_abstractors: Optional[list[list[int]]] = None, - num_grids: Optional[list[int]] = None, + num_queries_vis_abstractors: list[list[int]] | None = None, + num_grids: list[int] | None = None, ) -> torch.Tensor: if self.prenorm is not None: x = self.prenorm(x) @@ -493,8 +493,8 @@ def forward( def _forward( self, x: torch.Tensor, - num_queries_vis_abstractors: Optional[list[list[int]]] = None, - num_grids: Optional[list[int]] = None, + num_queries_vis_abstractors: list[list[int]] | None = None, + num_grids: list[int] | None = None, ) -> torch.Tensor: # x: [B, L, dim] B, L, dim = x.shape @@ -515,8 +515,8 @@ def _forward( def _forward_adaptive_num_query( self, x: torch.Tensor, - num_queries_vis_abstractors: Optional[list[list[int]]] = None, - num_grids: Optional[list[int]] = None, + num_queries_vis_abstractors: list[list[int]] | None = None, + num_grids: list[int] | None = None, ) -> list[torch.Tensor]: # self.net is consisted by 3 layers (s1, sampler, s2) assert len(self.net) == 3 @@ -604,7 +604,7 @@ def __init__( *, vllm_config: VllmConfig, prefix: str = "", - **kwargs: Optional[Any], + **kwargs: Any | None, ) -> None: super().__init__() @@ -662,7 +662,7 @@ def __init__( # self.reduction = self._init_reduction_type(use_sum_loss) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return IMAGE_TOKEN if modality.startswith("video"): @@ -673,7 +673,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def _parse_and_validate_image_input( self, **kwargs: object, - ) -> Optional[HCXVisionImageInputs]: + ) -> HCXVisionImageInputs | None: pixel_values_images = kwargs.pop("pixel_values_images", None) if pixel_values_images is None: @@ -689,7 +689,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object, - ) -> Optional[HCXVisionVideoInputs]: + ) -> HCXVisionVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) if pixel_values_videos is None: @@ -749,12 +749,12 @@ def get_multimodal_embeddings( for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = 
self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -762,10 +762,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -946,7 +946,7 @@ def _prepare_multimodal_kwargs(self, **kwargs: object): def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights( @@ -1062,7 +1062,7 @@ def select_best_resolution(original_size: tuple, possible_resolutions: list) -> def get_anyres_image_grid_shape( image_size: tuple[int, int], - grid_pinpoints: Union[str, list[tuple[int, int]]], + grid_pinpoints: str | list[tuple[int, int]], patch_size: int, ) -> tuple[int, int]: possible_resolutions = ( diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 02c46a11a179..727c8ec0397c 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -19,7 +19,6 @@ """PyTorch Idefics2 model.""" from collections.abc import Iterable -from typing import Optional import torch from torch import nn @@ -77,7 +76,7 @@ def forward( self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, - tgt_sizes: Optional[torch.IntTensor] = None, + tgt_sizes: torch.IntTensor | None = None, ) -> torch.Tensor: batch_size, _, max_im_h, max_im_w = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype @@ -124,7 +123,7 @@ class Idefics2VisionAttention(nn.Module): def __init__( self, config: Idefics2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -185,7 +184,7 @@ class Idefics2VisionMLP(nn.Module): def __init__( self, config: Idefics2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -220,7 +219,7 @@ class Idefics2EncoderLayer(nn.Module): def __init__( self, config: Idefics2Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -275,9 +274,9 @@ class Idefics2Encoder(nn.Module): def __init__( self, config: Idefics2Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -326,9 +325,9 @@ class Idefics2VisionTransformer(nn.Module): def __init__( self, config: Idefics2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - 
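# Illustrative note, not part of the diff: multimodal_embeddings is accumulated
# as a tuple, while _process_image_input()/_process_video_input() may return a
# list of per-item tensors; converting with tuple() keeps the += concatenation
# valid, since a tuple cannot be concatenated with a list directly:
embeddings: tuple = ()
per_item = ["emb0", "emb1"]        # stand-ins for torch.Tensors
try:
    embeddings += per_item          # tuple += list -> TypeError
except TypeError:
    pass
embeddings += tuple(per_item)       # works: ("emb0", "emb1")
assert embeddings == ("emb0", "emb1")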
num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, require_post_norm: bool = True, prefix: str = "", use_data_parallel: bool = False, @@ -370,8 +369,8 @@ def get_input_embeddings(self): def forward( self, pixel_values, - patch_attention_mask: Optional[torch.BoolTensor] = None, - tgt_sizes: Optional[torch.IntTensor] = None, + patch_attention_mask: torch.BoolTensor | None = None, + tgt_sizes: torch.IntTensor | None = None, ) -> torch.Tensor: hidden_states = self.embeddings( pixel_values=pixel_values, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index effdbdc1ac38..06ca8c488634 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -18,7 +18,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal, TypeAlias import torch from torch import nn @@ -91,14 +91,14 @@ class Idefics3ImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] -ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] +ImageInputs: TypeAlias = Idefics3ImagePixelInputs | Idefics3ImageEmbeddingInputs class Idefics3ProcessingInfo(BaseProcessingInfo): def get_hf_processor(self, **kwargs: object) -> Idefics3Processor: return self.ctx.get_hf_processor(Idefics3Processor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def _resize_output_size( @@ -106,9 +106,9 @@ def _resize_output_size( *, height: int, width: int, - max_len: Optional[int] = None, + max_len: int | None = None, min_len: int = 1, - max_size: Optional[int] = None, + max_size: int | None = None, ) -> tuple[int, int]: # Set default value for max_len if not provided max_len = max(height, width) if max_len is None else max_len @@ -165,7 +165,7 @@ def _get_image_feature_grid_size( *, image_width: int, image_height: int, - processor: Optional[Idefics3Processor], + processor: Idefics3Processor | None, ) -> tuple[int, int]: if processor is None: processor = self.get_hf_processor() @@ -197,7 +197,7 @@ def get_num_patches( *, image_width: int, image_height: int, - processor: Optional[Idefics3Processor], + processor: Idefics3Processor | None, ) -> int: grid_w, grid_h = self._get_image_feature_grid_size( image_width=image_width, @@ -208,7 +208,7 @@ def get_num_patches( return grid_w * grid_h + 1 def _get_image_token( - self, processor: Optional[Idefics3Processor] + self, processor: Idefics3Processor | None ) -> tuple[str, str, str]: if processor is None: processor = self.get_hf_processor() @@ -223,7 +223,7 @@ def get_image_repl( *, image_width: int, image_height: int, - processor: Optional[Idefics3Processor], + processor: Idefics3Processor | None, ) -> str: if processor is None: processor = self.get_hf_processor() @@ -269,7 +269,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[Idefics3Processor], + processor: Idefics3Processor | None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -305,7 +305,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) hf_processor = 
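# Illustrative note, not part of the diff: ImageInputs above is now declared
# with typing.TypeAlias, which marks the module-level assignment as a type
# alias (rather than a runtime value) even though the right-hand side is a
# plain PEP 604 union.  A minimal standalone version of the pattern, with
# hypothetical stand-in classes:
from typing import TypeAlias

class PixelInputs: ...
class EmbeddingInputs: ...

ImageInputsSketch: TypeAlias = PixelInputs | EmbeddingInputs

def parse(raw: dict) -> ImageInputsSketch | None:
    if "pixel_values" in raw:
        return PixelInputs()
    if "image_embeds" in raw:
        return EmbeddingInputs()
    return None

assert isinstance(parse({"pixel_values": object()}), PixelInputs)
assert parse({}) is None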
self.info.get_hf_processor() @@ -425,7 +425,7 @@ class Idefics3SimpleMLP(nn.Module): def __init__( self, config: Idefics3Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -448,7 +448,7 @@ class Idefics3Connector(nn.Module): def __init__( self, config: Idefics3Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -557,9 +557,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.text_model( input_ids, positions, @@ -590,7 +590,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -621,9 +621,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.text_model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.text_config.vocab_size) - def _parse_and_validate_image_input( - self, **kwargs: object - ) -> Optional[ImageInputs]: + def _parse_and_validate_image_input(self, **kwargs: object) -> ImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -663,7 +661,7 @@ def _process_image_pixels(self, inputs: Idefics3ImagePixelInputs) -> torch.Tenso def _process_image_input( self, image_input: ImageInputs, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: if image_input["type"] == "image_embeds": return image_input["data"] @@ -687,10 +685,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 38c9d5abb587..1bc5f5ae5419 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable, Mapping, MutableSequence +from collections.abc import Callable, Iterable, Mapping, MutableSequence from typing import ( TYPE_CHECKING, - Callable, ClassVar, Literal, - Optional, Protocol, - Union, + TypeAlias, overload, runtime_checkable, ) @@ -26,7 +24,7 @@ from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.utils import supports_kw +from vllm.utils.func_utils import supports_kw from .interfaces_base import VllmModel, is_pooling_model @@ -34,10 +32,14 @@ from vllm.config import VllmConfig 
from vllm.model_executor.models.utils import WeightsMapper from vllm.sequence import IntermediateTensors +else: + VllmConfig = object + WeightsMapper = object + IntermediateTensors = object logger = init_logger(__name__) -MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]] +MultiModalEmbeddings: TypeAlias = list[Tensor] | Tensor | tuple[Tensor, ...] """ The output embeddings must be one of the following formats: @@ -79,7 +81,7 @@ class SupportsMultiModal(Protocol): """ @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: """ Get the placeholder text for the `i`th `modality` item in the prompt. """ @@ -127,7 +129,7 @@ def _get_text_embeddings( input_ids: Tensor, get_input_embeddings: Callable[[Tensor], Tensor], *, - is_multimodal: Optional[Tensor], + is_multimodal: Tensor | None, handle_oov_mm_token: bool, ) -> Tensor: if handle_oov_mm_token and is_multimodal is not None: @@ -145,9 +147,9 @@ def _get_text_embeddings( def get_input_embeddings( self, input_ids: Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[Tensor] = None, + is_multimodal: Tensor | None = None, handle_oov_mm_token: bool = False, ) -> Tensor: """ @@ -236,16 +238,16 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: ... def supports_multimodal( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: + model: type[object] | object, +) -> TypeIs[type[SupportsMultiModal]] | TypeIs[SupportsMultiModal]: return getattr(model, "supports_multimodal", False) -def supports_multimodal_raw_input_only(model: Union[type[object], object]) -> bool: +def supports_multimodal_raw_input_only(model: type[object] | object) -> bool: return getattr(model, "supports_multimodal_raw_input_only", False) -def supports_multimodal_encoder_tp_data(model: Union[type[object], object]) -> bool: +def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool: return getattr(model, "supports_encoder_tp_data", False) @@ -260,8 +262,8 @@ def supports_multimodal_pruning(model: object) -> TypeIs[SupportsMultiModalPruni def supports_multimodal_pruning( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsMultiModalPruning]], TypeIs[SupportsMultiModalPruning]]: + model: type[object] | object, +) -> TypeIs[type[SupportsMultiModalPruning]] | TypeIs[SupportsMultiModalPruning]: return getattr(model, "supports_multimodal_pruning", False) @@ -279,7 +281,7 @@ class SupportsScoreTemplate(Protocol): """ @classmethod - def get_score_template(cls, query: str, document: str) -> Optional[str]: + def get_score_template(cls, query: str, document: str) -> str | None: """ Generate a full prompt by populating the score template with query and document content. """ # noqa: E501 @@ -304,8 +306,8 @@ def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]: ... def supports_score_template( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsScoreTemplate]], TypeIs[SupportsScoreTemplate]]: + model: type[object] | object, +) -> TypeIs[type[SupportsScoreTemplate]] | TypeIs[SupportsScoreTemplate]: return getattr(model, "supports_score_template", False) @@ -325,7 +327,7 @@ class SupportsLoRA(Protocol): # are empty by default. 
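# Illustrative note, not part of the diff: the new `else` branch binds the
# TYPE_CHECKING-only names to `object`, so the unquoted annotations below
# (e.g. `IntermediateTensors | None`) still evaluate at import time without
# pulling in the heavier modules.  A minimal standalone version of the
# pattern, with a hypothetical heavy_module:
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from heavy_module import HeavyConfig  # only seen by the type checker
else:
    HeavyConfig = object                   # runtime placeholder

def configure(cfg: HeavyConfig | None = None) -> None:
    # The annotation evaluates at runtime because HeavyConfig is bound to
    # *something*; the type checker still sees the real class.
    ...

configure()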
embedding_modules: ClassVar[dict[str, str]] = {} embedding_padding_modules: ClassVar[list[str]] = [] - packed_modules_mapping: ClassVar[dict[str, list[str]]] = {} + packed_modules_mapping: dict[str, list[str]] = {} # We can't use runtime_checkable with ClassVar for issubclass checks @@ -348,8 +350,8 @@ def supports_lora(model: object) -> TypeIs[SupportsLoRA]: ... def supports_lora( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]: + model: type[object] | object, +) -> TypeIs[type[SupportsLoRA]] | TypeIs[SupportsLoRA]: result = _supports_lora(model) if not result: @@ -379,7 +381,7 @@ def supports_lora( return result -def _supports_lora(model: Union[type[object], object]) -> bool: +def _supports_lora(model: type[object] | object) -> bool: if isinstance(model, type): return isinstance(model, _SupportsLoRAType) @@ -404,15 +406,15 @@ def make_empty_intermediate_tensors( batch_size: int, dtype: torch.dtype, device: torch.device, - ) -> "IntermediateTensors": + ) -> IntermediateTensors: """Called when PP rank > 0 for profiling purposes.""" ... def forward( self, *, - intermediate_tensors: Optional["IntermediateTensors"], - ) -> Union[Tensor, "IntermediateTensors"]: + intermediate_tensors: IntermediateTensors | None, + ) -> IntermediateTensors | None: """ Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when PP rank > 0. @@ -434,13 +436,13 @@ def make_empty_intermediate_tensors( batch_size: int, dtype: torch.dtype, device: torch.device, - ) -> "IntermediateTensors": ... + ) -> IntermediateTensors: ... def forward( self, *, - intermediate_tensors: Optional["IntermediateTensors"], - ) -> Union[Tensor, "IntermediateTensors"]: ... + intermediate_tensors: IntermediateTensors | None, + ) -> Tensor | IntermediateTensors: ... @overload @@ -452,8 +454,8 @@ def supports_pp(model: object) -> TypeIs[SupportsPP]: ... def supports_pp( - model: Union[type[object], object], -) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]: + model: type[object] | object, +) -> bool | TypeIs[type[SupportsPP]] | TypeIs[SupportsPP]: supports_attributes = _supports_pp_attributes(model) supports_inspect = _supports_pp_inspect(model) @@ -487,14 +489,14 @@ def supports_pp( return supports_attributes and supports_inspect -def _supports_pp_attributes(model: Union[type[object], object]) -> bool: +def _supports_pp_attributes(model: type[object] | object) -> bool: if isinstance(model, type): return isinstance(model, _SupportsPPType) return isinstance(model, SupportsPP) -def _supports_pp_inspect(model: Union[type[object], object]) -> bool: +def _supports_pp_inspect(model: type[object] | object) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): return False @@ -523,8 +525,8 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: ... def has_inner_state( - model: Union[type[object], object], -) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]: + model: type[object] | object, +) -> TypeIs[type[HasInnerState]] | TypeIs[HasInnerState]: return getattr(model, "has_inner_state", False) @@ -550,8 +552,8 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: ... 
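# Illustrative sketch, not part of the diff: the supports_*() helpers follow a
# common shape -- two @overload declarations so TypeIs can narrow either a
# class or an instance, plus a single runtime implementation that reads a
# marker attribute.  A generic version of the pattern (SupportsFoo and
# supports_foo are hypothetical names):
from typing import ClassVar, Literal, Protocol, overload

from typing_extensions import TypeIs

class SupportsFoo(Protocol):
    supports_foo: ClassVar[Literal[True]] = True

@overload
def supports_foo(model: type[object]) -> TypeIs[type[SupportsFoo]]: ...
@overload
def supports_foo(model: object) -> TypeIs[SupportsFoo]: ...

def supports_foo(
    model: type[object] | object,
) -> TypeIs[type[SupportsFoo]] | TypeIs[SupportsFoo]:
    # Same runtime behaviour as the helpers above: read the marker attribute.
    return getattr(model, "supports_foo", False)

class MarkedModel:
    supports_foo: ClassVar[Literal[True]] = True

assert supports_foo(MarkedModel) and supports_foo(MarkedModel())
assert not supports_foo(object())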
def is_attention_free( - model: Union[type[object], object], -) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]: + model: type[object] | object, +) -> TypeIs[type[IsAttentionFree]] | TypeIs[IsAttentionFree]: return getattr(model, "is_attention_free", False) @@ -570,7 +572,7 @@ class IsHybrid(Protocol): @classmethod def get_mamba_state_shape_from_config( cls, - vllm_config: "VllmConfig", + vllm_config: VllmConfig, use_v1: bool = True, ) -> tuple[tuple[int, int], tuple[int, int, int]]: """Calculate shapes for Mamba's convolutional and state caches. @@ -596,8 +598,8 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: ... def is_hybrid( - model: Union[type[object], object], -) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]: + model: type[object] | object, +) -> TypeIs[type[IsHybrid]] | TypeIs[IsHybrid]: return getattr(model, "is_hybrid", False) @@ -671,7 +673,9 @@ def update_physical_experts_metadata( def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: - return isinstance(model, MixtureOfExperts) + return ( + isinstance(model, MixtureOfExperts) and getattr(model, "num_moe_layers", 0) > 0 + ) @runtime_checkable @@ -688,8 +692,8 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: ... def has_noops( - model: Union[type[object], object], -) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]: + model: type[object] | object, +) -> TypeIs[type[HasNoOps]] | TypeIs[HasNoOps]: return getattr(model, "has_noops", False) @@ -711,23 +715,23 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ... def _supports_cross_encoding( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + model: type[object] | object, +) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]: return getattr(model, "supports_cross_encoding", False) def supports_cross_encoding( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + model: type[object] | object, +) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]: return is_pooling_model(model) and _supports_cross_encoding(model) class SupportsQuant: """The interface required for all models that support quantization.""" - hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None - packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None - quant_config: Optional[QuantizationConfig] = None + hf_to_vllm_mapper: ClassVar[WeightsMapper | None] = None + packed_modules_mapping: ClassVar[dict[str, list[str]] | None] = None + quant_config: QuantizationConfig | None = None def __new__(cls, *args, **kwargs) -> Self: instance = super().__new__(cls) @@ -749,7 +753,7 @@ def __new__(cls, *args, **kwargs) -> Self: return instance @staticmethod - def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]: + def _find_quant_config(*args, **kwargs) -> QuantizationConfig | None: """Find quant config passed through model constructor args""" from vllm.config import VllmConfig # avoid circular import @@ -797,10 +801,10 @@ def get_generation_prompt( audio: np.ndarray, stt_config: SpeechToTextConfig, model_config: ModelConfig, - language: Optional[str], + language: str | None, task_type: Literal["transcribe", "translate"], request_prompt: str, - to_language: Optional[str], + to_language: str | None, ) -> PromptType: """Get the prompt for the ASR model. 
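# Illustrative note, not part of the diff: isinstance() against a
# runtime_checkable Protocol only checks that the listed members exist, so a
# model that merely exposes the MixtureOfExperts attributes (with
# num_moe_layers == 0) would still match; the added num_moe_layers > 0 guard
# in is_mixture_of_experts() excludes that case.  Toy demonstration:
from typing import Protocol, runtime_checkable

@runtime_checkable
class ToyMoE(Protocol):
    num_moe_layers: int

class DenseModel:
    num_moe_layers = 0  # attribute exists, but there are no MoE layers

m = DenseModel()
assert isinstance(m, ToyMoE)                                 # structural match
assert not (isinstance(m, ToyMoE) and m.num_moe_layers > 0)  # guarded check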
The model has control over the construction, as long as it @@ -813,7 +817,7 @@ def get_other_languages(cls) -> Mapping[str, str]: return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages} @classmethod - def validate_language(cls, language: Optional[str]) -> Optional[str]: + def validate_language(cls, language: str | None) -> str | None: """ Ensure the language specified in the transcription request is a valid ISO 639-1 language code. If the request language is @@ -850,7 +854,7 @@ def get_num_audio_tokens( audio_duration_s: float, stt_config: SpeechToTextConfig, model_config: ModelConfig, - ) -> Optional[int]: + ) -> int | None: """ Map from audio duration to number of audio tokens produced by the ASR model, without running a forward pass. @@ -870,32 +874,11 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: ... def supports_transcription( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + model: type[object] | object, +) -> TypeIs[type[SupportsTranscription]] | TypeIs[SupportsTranscription]: return getattr(model, "supports_transcription", False) -@runtime_checkable -class SupportsV0Only(Protocol): - """Models with this interface are not compatible with V1 vLLM.""" - - supports_v0_only: ClassVar[Literal[True]] = True - - -@overload -def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]: ... - - -@overload -def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: ... - - -def supports_v0_only( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]: - return getattr(model, "supports_v0_only", False) - - @runtime_checkable class SupportsEagle3(Protocol): """The interface required for models that support @@ -942,8 +925,8 @@ def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]: ... def supports_eagle3( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]: + model: type[object] | object, +) -> TypeIs[type[SupportsEagle3]] | TypeIs[SupportsEagle3]: return isinstance(model, SupportsEagle3) @@ -964,12 +947,12 @@ def get_mrope_input_positions( self, input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - second_per_grid_ts: Optional[list[float]] = None, + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None = None, context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: """ @@ -1007,6 +990,6 @@ def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]: ... 
def supports_mrope( - model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsMRoPE]], TypeIs[SupportsMRoPE]]: + model: type[object] | object, +) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]: return isinstance(model, SupportsMRoPE) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index b697eb25b5cc..d87a65a47083 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -5,9 +5,7 @@ Any, ClassVar, Literal, - Optional, Protocol, - Union, overload, runtime_checkable, ) @@ -17,7 +15,7 @@ from typing_extensions import TypeIs, TypeVar from vllm.logger import init_logger -from vllm.utils import supports_kw +from vllm.utils.func_utils import supports_kw if TYPE_CHECKING: from vllm.config import VllmConfig @@ -63,12 +61,12 @@ def forward( ) -> T_co: ... -def _check_vllm_model_init(model: Union[type[object], object]) -> bool: +def _check_vllm_model_init(model: type[object] | object) -> bool: model_init = model.__init__ return supports_kw(model_init, "vllm_config") -def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -> bool: +def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool: model_get_input_embeddings = getattr(model, "get_input_embeddings", None) if not callable(model_get_input_embeddings): logger.warning( @@ -80,7 +78,7 @@ def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) - return True -def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: +def _check_vllm_model_forward(model: type[object] | object) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): return False @@ -108,8 +106,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: ... def is_vllm_model( - model: Union[type[object], object], -) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: + model: type[object] | object, +) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]: return ( _check_vllm_model_init(model) and _check_vllm_model_get_input_embeddings(model) @@ -124,7 +122,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): def compute_logits( self, hidden_states: T, - ) -> Optional[T]: + ) -> T | None: """Return `None` if TP rank > 0.""" ... @@ -140,10 +138,8 @@ def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration def is_text_generation_model( - model: Union[type[object], object], -) -> Union[ - TypeIs[type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration] -]: + model: type[object] | object, +) -> TypeIs[type[VllmModelForTextGeneration]] | TypeIs[VllmModelForTextGeneration]: if not is_vllm_model(model): return False @@ -190,8 +186,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ... 
def is_pooling_model( - model: Union[type[object], object], -) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: + model: type[object] | object, +) -> TypeIs[type[VllmModelForPooling]] | TypeIs[VllmModelForPooling]: if not is_vllm_model(model): return False @@ -211,5 +207,5 @@ def func(model: _T) -> _T: return func -def get_default_pooling_type(model: Union[type[object], object]) -> str: +def get_default_pooling_type(model: type[object] | object) -> str: return getattr(model, "default_pooling_type", "LAST") diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 9435ff0d26cf..03918127c6ae 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -9,7 +9,6 @@ # -------------------------------------------------------- from collections.abc import Iterable from functools import partial -from typing import Optional import torch import torch.nn as nn @@ -121,8 +120,8 @@ def get_input_embeddings(self): def forward( self, - pixel_values: Optional[torch.Tensor] = None, - pixel_embeds: Optional[torch.Tensor] = None, + pixel_values: torch.Tensor | None = None, + pixel_embeds: torch.Tensor | None = None, ) -> torch.FloatTensor: if pixel_values is None and pixel_embeds is None: raise ValueError("You have to specify pixel_values or pixel_embeds") @@ -144,7 +143,7 @@ class InternParallelAttention(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, num_dummy_heads: int = 0, prefix: str = "", @@ -240,7 +239,7 @@ class InternMLP(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -277,7 +276,7 @@ class InternVisionEncoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, num_dummy_heads: int = 0, prefix: str = "", @@ -312,7 +311,7 @@ def __init__( def _init_attn( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, num_dummy_heads: int, prefix: str = "", @@ -350,9 +349,9 @@ class InternVisionEncoder(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, num_dummy_heads: int = 0, prefix: str = "", use_data_parallel: bool = False, @@ -395,9 +394,9 @@ class InternVisionModel(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, num_dummy_heads: int = 0, prefix: str = "", use_data_parallel: bool = False, @@ -422,8 +421,8 @@ def get_input_embeddings(self): def forward( self, - pixel_values: Optional[torch.Tensor] = None, - pixel_embeds: Optional[torch.Tensor] = None, + pixel_values: torch.Tensor | None = None, + pixel_embeds: torch.Tensor | None = None, ) -> torch.FloatTensor: if pixel_values is None and pixel_embeds is None: raise ValueError("You have to specify pixel_values or pixel_embeds") diff --git 
a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 128791541b3d..c5bbd5497a14 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from functools import partial from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -54,7 +54,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -92,10 +92,10 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -198,8 +198,8 @@ class InternLMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -232,7 +232,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -291,9 +291,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -359,8 +359,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds @@ -370,7 +370,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.output, hidden_states) return logits @@ -444,16 +444,16 @@ def __init__( assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, + {"token_classify": Pooler.for_token_classify(pooler_config)} ) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) diff --git a/vllm/model_executor/models/internlm2_ve.py 
b/vllm/model_executor/models/internlm2_ve.py index 5344ded280b2..6dc081e34157 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -25,8 +24,8 @@ class InternLM2VEDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -66,8 +65,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - visual_token_mask: Optional[torch.Tensor] = None, + residual: torch.Tensor | None, + visual_token_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -107,10 +106,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - visual_token_mask: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + visual_token_mask: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index 06c7c8ccd0b5..1f251935a70a 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -7,7 +7,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal, TypeAlias import regex as re import torch @@ -111,12 +111,10 @@ class InternS1ImageEmbeddingInputs(TensorSchema): """ type: Literal["image_embeds"] = "image_embeds" - data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], TensorShape("ni", "tifs", "hs") - ] + data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("ni", "tifs", "hs")] -InternS1ImageInputs = Union[InternS1ImagePixelInputs, InternS1ImageEmbeddingInputs] +InternS1ImageInputs: TypeAlias = InternS1ImagePixelInputs | InternS1ImageEmbeddingInputs class InternS1VideoPixelInputs(TensorSchema): @@ -143,12 +141,10 @@ class InternS1VideoEmbeddingInputs(TensorSchema): """ type: Literal["video_embeds"] = "video_embeds" - data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], TensorShape("nv", "tvfs", "hs") - ] + data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("nv", "tvfs", "hs")] -InternS1VideoInputs = Union[InternS1VideoPixelInputs, InternS1VideoEmbeddingInputs] +InternS1VideoInputs: TypeAlias = InternS1VideoPixelInputs | InternS1VideoEmbeddingInputs def resolve_interns1_min_max_num( @@ -186,11 +182,14 @@ class InternS1ProcessingInfo(BaseProcessingInfo): def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: hf_processor = self.ctx.get_hf_processor(InternVLProcessor, **kwargs) hf_processor.video_processor = cached_video_processor_from_config( - self.ctx.model_config, 
processor_cls=InternVLVideoProcessor, **kwargs + self.ctx.model_config, + processor_cls=InternVLVideoProcessor, + size=hf_processor.image_processor.size, + **kwargs, ) return hf_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": None} def get_num_image_tokens( @@ -198,7 +197,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional["GotOcr2ImageProcessorFast"] = None, + processor: GotOcr2ImageProcessorFast | None = None, ) -> int: if processor is None: processor = self.get_hf_processor().image_processor @@ -213,7 +212,7 @@ def get_num_image_tokens( num_image_tokens = self.get_hf_processor().image_seq_length * num_image_patches return num_image_tokens - def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None): + def resolve_target_ratios(self, use_thumbnail: bool | None = None): image_processor = self.get_hf_processor().image_processor min_dynamic_patch = image_processor.min_patches max_dynamic_patch = image_processor.max_patches @@ -298,7 +297,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = self.info.get_num_frames_with_most_features( @@ -523,7 +522,7 @@ class InternS1ForConditionalGeneration( ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: # transformers InternVLProcessor uses <IMG_CONTEXT> as the separator # refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116 if modality.startswith("image"): @@ -576,7 +575,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def _init_vision_model( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, prefix: str, ): @@ -620,7 +619,7 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[InternS1ImageInputs]: + ) -> InternS1ImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -635,8 +634,11 @@ def _parse_and_validate_image_input( ) image_token_id = kwargs["image_token_id"] - assert isinstance(image_token_id, torch.Tensor) - self.img_context_token_id = image_token_id.flatten().unique().item() + if isinstance(image_token_id, torch.Tensor): + image_token_id = image_token_id.flatten().unique().item() + + assert isinstance(image_token_id, int) + self.img_context_token_id = image_token_id if pixel_values is not None: h, w = self.config.vision_config.image_size @@ -654,7 +656,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[InternS1VideoInputs]: + ) -> InternS1VideoInputs | None: pixel_values_flat_video = kwargs.pop("pixel_values_videos", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("video_embeds", None) @@ -669,8 +671,11 @@ def _parse_and_validate_video_input( ) video_token_id = 
kwargs["video_token_id"] - assert isinstance(video_token_id, torch.Tensor) - self.video_context_token_id = video_token_id.flatten().unique().item() + if isinstance(video_token_id, torch.Tensor): + video_token_id = video_token_id.flatten().unique().item() + + assert isinstance(video_token_id, int) + self.video_context_token_id = video_token_id if pixel_values_flat_video is not None: h, w = self.config.vision_config.image_size @@ -688,7 +693,7 @@ def _parse_and_validate_video_input( def _process_vision_input( self, - image_input: Union[InternS1ImageInputs, InternS1VideoInputs], + image_input: InternS1ImageInputs | InternS1VideoInputs, ) -> tuple[torch.Tensor, ...]: if ( image_input["type"] == "image_embeds" @@ -751,21 +756,21 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_vision_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_vision_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_vision_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: @@ -786,8 +791,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> IntermediateTensors: if intermediate_tensors is not None: @@ -807,7 +812,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py index f5965bdf7c9c..507503d75046 100644 --- a/vllm/model_executor/models/interns1_vit.py +++ b/vllm/model_executor/models/interns1_vit.py @@ -8,7 +8,6 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -139,7 +138,7 @@ def interpolate_pos_encoding( def forward( self, pixel_values: torch.Tensor, - bool_masked_pos: Optional[torch.BoolTensor] = None, + bool_masked_pos: torch.BoolTensor | None = None, ) -> torch.Tensor: _, _, height, width = pixel_values.shape embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) @@ -218,16 +217,15 @@ def __init__( self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) def forward(self, x: torch.Tensor) -> torch.Tensor: - B, N, C = x.shape + """x shape: (B, N, C)""" q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) if self.qk_normalization: - B_, 
N_, H_, D_ = q.shape - q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) - k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) + q = self.q_norm(q) + k = self.k_norm(k) # Use unified MultiHeadAttention with automatic backend selection x = self.attn(q, k, v) @@ -240,7 +238,7 @@ class InternS1VisionMLP(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -274,7 +272,7 @@ class InternS1VisionLayer(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, num_dummy_heads: int = 0, prefix: str = "", @@ -309,7 +307,7 @@ def __init__( def _init_attn( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, num_dummy_heads: int, prefix: str = "", @@ -337,9 +335,9 @@ class InternS1VisionEncoder(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, num_dummy_heads: int = 0, prefix: str = "", ): @@ -376,9 +374,9 @@ class InternS1VisionModel(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, num_dummy_heads: int = 0, prefix: str = "", ) -> None: @@ -404,8 +402,8 @@ def get_input_embeddings(self): def forward( self, - pixel_values: Optional[torch.Tensor] = None, - pixel_embeds: Optional[torch.Tensor] = None, + pixel_values: torch.Tensor | None = None, + pixel_embeds: torch.Tensor | None = None, ) -> torch.FloatTensor: if pixel_values is None and pixel_embeds is None: raise ValueError("You have to specify pixel_values or pixel_embeds") diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 3cd3807dd888..e2d2647f0177 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -10,7 +10,7 @@ import os from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, Optional, TypeVar, Union +from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy.typing as npt import torch @@ -51,8 +51,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import set_default_torch_num_threads from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_num_threads from .interfaces import ( MultiModalEmbeddings, @@ -94,10 +94,10 @@ class InternVLImageEmbeddingInputs(TensorSchema): """ type: Literal["image_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] + data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")] -InternVLImageInputs = Union[InternVLImagePixelInputs, InternVLImageEmbeddingInputs] +InternVLImageInputs: TypeAlias = InternVLImagePixelInputs | InternVLImageEmbeddingInputs class InternVLVideoPixelInputs(TensorSchema): @@ -124,10 +124,10 
@@ class InternVLVideoEmbeddingInputs(TensorSchema): """ type: Literal["video_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] + data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")] -InternVLVideoInputs = Union[InternVLVideoPixelInputs, InternVLVideoEmbeddingInputs] +InternVLVideoInputs: TypeAlias = InternVLVideoPixelInputs | InternVLVideoEmbeddingInputs # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B @@ -349,9 +349,9 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, ) -> None: super().__init__() @@ -391,17 +391,17 @@ def image_token_id(self) -> int: def get_image_repl( self, feature_size: int, - num_patches: Optional[int], + num_patches: int | None, ) -> PromptUpdateDetails[str]: raise NotImplementedError def resolve_min_max_num( self, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - use_thumbnail: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + use_thumbnail: bool | None = None, ) -> tuple[int, int]: min_dynamic_patch = ( self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch @@ -426,10 +426,10 @@ def resolve_min_max_num( def resolve_target_ratios( self, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - use_thumbnail: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + use_thumbnail: bool | None = None, ) -> list[tuple[int, int]]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -463,9 +463,9 @@ def get_num_image_tokens( def _images_to_pixel_values_lst( self, images: list[Image.Image], - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, ) -> list[torch.Tensor]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -489,9 +489,9 @@ def _preprocess_image( self, text: list[str], images: list[Image.Image], - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, ) -> tuple[list[str], dict[str, torch.Tensor]]: if len(images) == 0: image_inputs = {} @@ -517,7 +517,7 @@ def _preprocess_image( text = [t.replace("<image>", image_repl.full, 1) for t in text] return text, image_inputs - def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): + def _make_batch_input(self, input_item: Any | list[Any] | None = None): if input_item is None: input_item = [] if not isinstance(input_item, list): @@ -526,12 +526,12 @@ def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): def __call__( self, - text: Optional[Union[str, list[str]]] = None, - images: 
Optional[Union[Image.Image, list[Image.Image]]] = None, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: str | list[str] | None = None, + images: Image.Image | list[Image.Image] | None = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + return_tensors: str | TensorType | None = None, ) -> BatchFeature: text, images = [self._make_batch_input(x) for x in (text, images)] @@ -563,10 +563,10 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - video_token: Optional[str] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + video_token: str | None = None, ) -> None: super().__init__( config=config, @@ -583,7 +583,7 @@ def image_token_id(self) -> int: return self.tokenizer.get_vocab()[IMG_CONTEXT] @property - def video_token_id(self) -> Optional[int]: + def video_token_id(self) -> int | None: if self.video_token is None: return None return self.tokenizer.get_vocab().get(self.video_token, None) @@ -595,7 +595,7 @@ def supports_video(self) -> bool: def _videos_to_pixel_values_lst( self, videos: list[npt.NDArray], - dynamic_image_size: Optional[bool] = None, + dynamic_image_size: bool | None = None, ) -> list[torch.Tensor]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=1, @@ -619,7 +619,7 @@ def _preprocess_video( self, text: list[str], videos: list[npt.NDArray], - dynamic_image_size: Optional[bool] = None, + dynamic_image_size: bool | None = None, ): if len(videos) == 0 or not self.supports_video: video_inputs = {} @@ -646,13 +646,13 @@ def _preprocess_video( def __call__( self, - text: Optional[Union[str, list[str]]] = None, - images: Optional[Union[Image.Image, list[Image.Image]]] = None, - videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: str | list[str] | None = None, + images: Image.Image | list[Image.Image] | None = None, + videos: npt.NDArray | list[npt.NDArray] | None = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + return_tensors: str | TensorType | None = None, ) -> BatchFeature: text, images, videos = [ self._make_batch_input(x) for x in (text, images, videos) @@ -681,7 +681,7 @@ def __call__( def get_image_repl( self, feature_size: int, - num_patches: Optional[int], + num_patches: int | None, ) -> PromptUpdateDetails[str]: repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END @@ -691,7 +691,7 @@ def get_image_repl( def get_video_repl( self, feature_size: int, - num_patches: Optional[int] = None, + num_patches: int | None = None, video_context_token: str = IMG_CONTEXT, ) -> PromptUpdateDetails[str]: repl_features = video_context_token * self.num_image_token @@ -711,7 +711,7 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): def get_hf_processor(self, **kwargs: object) -> BaseInternVLProcessor: raise NotImplementedError - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + 
def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -719,7 +719,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[BaseInternVLProcessor], + processor: BaseInternVLProcessor | None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -779,7 +779,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) @@ -901,7 +901,7 @@ def get_supported_mm_limits(self): video_limit = {"video": None} if self.supports_video else {} return {**super().get_supported_mm_limits(), **video_limit} - def get_video_token(self) -> Optional[str]: + def get_video_token(self) -> str | None: text_model_type = self.get_hf_config().get_text_config().model_type video_token_map = { "qwen2": "<|video_pad|>", @@ -951,7 +951,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: dummy_image = super().get_dummy_mm_data( seq_len=seq_len, mm_counts=mm_counts, mm_options=mm_options @@ -1079,7 +1079,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) supports_encoder_tp_data = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" if modality.startswith("video"): @@ -1149,7 +1149,7 @@ def _patch_quant_config( def _init_vision_model( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, is_mono: bool, prefix: str, @@ -1217,7 +1217,7 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[InternVLImageInputs]: + ) -> InternVLImageInputs | None: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -1232,8 +1232,11 @@ def _parse_and_validate_image_input( ) image_token_id = kwargs["image_token_id"] - assert isinstance(image_token_id, torch.Tensor) - self.img_context_token_id = image_token_id.flatten().unique().item() + if isinstance(image_token_id, torch.Tensor): + image_token_id = image_token_id.flatten().unique().item() + + assert isinstance(image_token_id, int) + self.img_context_token_id = image_token_id if pixel_values_flat is not None: expected_h = expected_w = self.config.vision_config.image_size @@ -1250,7 +1253,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[InternVLVideoPixelInputs]: + ) -> InternVLVideoPixelInputs | None: pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("image_embeds", None) @@ -1265,8 +1268,11 @@ def _parse_and_validate_video_input( ) video_token_id = kwargs["video_token_id"] - assert isinstance(video_token_id, torch.Tensor) - self.video_context_token_id = video_token_id.flatten().unique().item() + if 
isinstance(video_token_id, torch.Tensor): + video_token_id = video_token_id.flatten().unique().item() + + assert isinstance(video_token_id, int) + self.video_context_token_id = video_token_id if pixel_values_flat_video is not None: expected_h = expected_w = self.config.vision_config.image_size @@ -1283,7 +1289,7 @@ def _parse_and_validate_video_input( def _process_vision_input( self, - image_input: Union[InternVLImageInputs, InternVLVideoInputs], + image_input: InternVLImageInputs | InternVLVideoInputs, ) -> tuple[torch.Tensor, ...]: if ( image_input["type"] == "image_embeds" @@ -1352,21 +1358,21 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_vision_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_vision_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_vision_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: @@ -1387,8 +1393,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> IntermediateTensors: if intermediate_tensors is not None: @@ -1413,7 +1419,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index d788ed7ec2af..1daaed80b144 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -24,7 +24,6 @@ import math from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -86,8 +85,8 @@ class JAISAttention(nn.Module): def __init__( self, config: JAISConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -147,7 +146,7 @@ def __init__( self, intermediate_size: int, config: JAISConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() hidden_size = config.hidden_size @@ -194,8 +193,8 @@ class JAISBlock(nn.Module): def __init__( self, config: JAISConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): 
super().__init__() @@ -277,9 +276,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[IntermediateTensors, torch.Tensor]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> IntermediateTensors | torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings(input_ids) @@ -341,9 +340,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[IntermediateTensors, torch.Tensor]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> IntermediateTensors | torch.Tensor: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -352,7 +351,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 0371458f5578..f8a87cf6965f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional import torch from torch import nn @@ -54,11 +53,11 @@ class JambaMoE(nn.Module): def __init__( self, config: JambaConfig, - num_experts: Optional[int] = None, - top_k: Optional[int] = None, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, + num_experts: int | None = None, + top_k: int | None = None, + params_dtype: torch.dtype | None = None, + tp_size: int | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -111,10 +110,10 @@ def __init__( self, config: JambaConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + is_lora_enabled: bool | None = False, prefix: str = "", **kwargs, ) -> None: @@ -159,7 +158,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -181,9 +180,9 @@ def __init__( self, config: JambaConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", **kwargs, ) -> None: @@ -266,7 +265,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -348,8 +347,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: 
Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -523,8 +522,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): hidden_states = self.model( @@ -568,7 +567,7 @@ def get_mamba_state_shape_from_config( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -605,10 +604,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.score + ), "classify": Pooler.for_classify( - pooler_config, - classifier=self.score, + pooler_config, classifier=self.score, act_fn="classify" + ), + "score": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="score" ), } ) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 9711eeeeec33..05a40837954d 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Optional import torch import torch.nn as nn @@ -98,21 +97,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.score = JinaVLScorer(vllm_config.model_config) self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), - "classify": Pooler.for_classify(pooler_config, classifier=self.score), - "score": Pooler.for_classify(pooler_config, classifier=self.score), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.score + ), + "classify": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="classify" + ), + "score": Pooler.for_classify( + pooler_config, classifier=self.score, act_fn="score" + ), } ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" raise ValueError("Only image modality is supported") @classmethod - def get_score_template(cls, query: str, document: str) -> Optional[str]: + def get_score_template(cls, query: str, document: str) -> str | None: return f"**Document**:\n{document}\n**Query**:\n{query}" @classmethod @@ -124,8 +129,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor: hidden_states = super().forward( diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 7ccbc81431f6..acfd51a6d0cc 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -4,7 +4,7 @@ from abc import abstractmethod from 
collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Literal, Optional, TypeVar, Union +from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy as np import torch @@ -153,7 +153,7 @@ class KeyeImageEmbeddingInputs(TensorSchema): image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] -KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs] +KeyeImageInputs: TypeAlias = KeyeImagePixelInputs | KeyeImageEmbeddingInputs class KeyeVideoPixelInputs(TensorSchema): @@ -188,7 +188,7 @@ class KeyeVideoEmbeddingInputs(TensorSchema): video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] -KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs] +KeyeVideoInputs: TypeAlias = KeyeVideoPixelInputs | KeyeVideoEmbeddingInputs class KeyeVisionEmbeddings(nn.Module): @@ -278,15 +278,9 @@ def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache: int = def forward( self, pixel_values: torch.FloatTensor, - position_ids: Optional[torch.Tensor] = None, - image_grid_thw: Optional[ - list[ - Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ] - ] - ] = None, + position_ids: torch.Tensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] + | None = None, interpolate_pos_encoding=False, ) -> torch.Tensor: if pixel_values.dim() == 4: @@ -357,8 +351,9 @@ class KeyeSiglipAttention(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + attn_backend_override: _Backend | None = None, ): super().__init__() self.config = config @@ -398,7 +393,9 @@ def __init__( # Detect attention implementation. 
self.attn_backend = get_vit_attn_backend( - head_size=self.head_dim, dtype=torch.get_default_dtype() + head_size=self.head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) self.use_upstream_fa = False @@ -416,10 +413,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - cu_seqlens: Optional[list[torch.Tensor]] = None, - rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = False, + cu_seqlens: list[torch.Tensor] | None = None, + rope_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split( @@ -524,9 +521,10 @@ def forward(self, seqlen: int) -> torch.Tensor: class KeyeSiglipEncoderLayer(nn.Module): def __init__( self, - config: Union[PretrainedConfig], - quant_config: Optional[QuantizationConfig] = None, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, prefix: str = "", + attn_backend_override: _Backend | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -535,6 +533,7 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -547,9 +546,9 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, - output_attentions: Optional[bool] = False, - cu_seqlens: Optional[list[torch.Tensor]] = None, - rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: bool | None = False, + cu_seqlens: list[torch.Tensor] | None = None, + rope_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> tuple[torch.FloatTensor]: residual = hidden_states @@ -577,8 +576,9 @@ class KeyeSiglipEncoder(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + attn_backend_override: _Backend | None = None, ): super().__init__() self.config = config @@ -591,6 +591,7 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.layers.{layer_idx}", + attn_backend_override=attn_backend_override, ) for layer_idx in range(config.num_hidden_layers) ] @@ -610,22 +611,16 @@ def flatten_list(image_grid_thw): def forward( self, inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cu_seqlens: Optional[list[torch.Tensor]] = None, - image_grid_thw: Optional[ - list[ - Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ] - ] - ] = None, - height_position_ids: Optional[torch.Tensor] = None, - width_position_ids: Optional[torch.Tensor] = None, - use_rope: Optional[bool] = False, - window_size: Optional[bool] = -1, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cu_seqlens: list[torch.Tensor] | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] + | None = None, + height_position_ids: torch.Tensor | None = None, + width_position_ids: torch.Tensor | None = None, + use_rope: bool | None = False, + window_size: bool | None = -1, vision_or_text: str = "vision", ) -> BaseModelOutput: device = 
inputs_embeds.device @@ -676,8 +671,9 @@ class KeyeSiglipVisionTransformer(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + attn_backend_override: _Backend | None = None, ): super().__init__() self.config = config @@ -688,35 +684,30 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.encoder", + attn_backend_override=attn_backend_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, pixel_values, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = False, - attention_mask: Optional[torch.Tensor] = None, - sample_indices: Optional[torch.Tensor] = None, - image_indices: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - height_position_ids: Optional[torch.Tensor] = None, - width_position_ids: Optional[torch.Tensor] = None, - cu_seqlens: Optional[list[torch.Tensor]] = None, - padding_mask: Optional[torch.Tensor] = None, - vision_return_embed_list: Optional[bool] = False, - image_grid_thw: Optional[ - list[ - Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ] - ] - ] = None, - return_pooler_output: Optional[bool] = True, - use_rope: Optional[bool] = False, - window_size: Optional[bool] = -1, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + interpolate_pos_encoding: bool | None = False, + attention_mask: torch.Tensor | None = None, + sample_indices: torch.Tensor | None = None, + image_indices: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + height_position_ids: torch.Tensor | None = None, + width_position_ids: torch.Tensor | None = None, + cu_seqlens: list[torch.Tensor] | None = None, + padding_mask: torch.Tensor | None = None, + vision_return_embed_list: bool | None = False, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] + | None = None, + return_pooler_output: bool | None = True, + use_rope: bool | None = False, + window_size: bool | None = -1, ) -> BaseModelOutputWithPooling: hidden_states = self.embeddings( pixel_values, @@ -763,8 +754,9 @@ class KeyeSiglipVisionModel(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", + attn_backend_override: _Backend | None = None, ): super().__init__() @@ -772,6 +764,7 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.vision_model", + attn_backend_override=attn_backend_override, ) self.quant_config = quant_config @@ -789,24 +782,18 @@ def get_input_embeddings(self) -> nn.Module: def forward( self, pixel_values, - sample_indices: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + sample_indices: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, interpolate_pos_encoding: bool = False, - position_ids: Optional[torch.Tensor] = None, - vision_return_embed_list: Optional[bool] = False, - image_grid_thw: Optional[ - list[ - Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ] - ] - ] = None, - cu_seqlens: Optional[list[torch.Tensor]] = None, - return_pooler_output: Optional[bool] = True, - use_rope: Optional[bool] = False, - window_size: Optional[bool] = -1, + 
position_ids: torch.Tensor | None = None, + vision_return_embed_list: bool | None = False, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] + | None = None, + cu_seqlens: list[torch.Tensor] | None = None, + return_pooler_output: bool | None = True, + use_rope: bool | None = False, + window_size: bool | None = -1, ) -> BaseModelOutputWithPooling: return self.vision_model( pixel_values=pixel_values, @@ -893,7 +880,7 @@ def __init__( self, text_config: PretrainedConfig, vision_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -927,9 +914,9 @@ def __init__( def forward( self, - image_features: Union[torch.Tensor, list[torch.Tensor]], + image_features: torch.Tensor | list[torch.Tensor], image_grid_thw: list[tuple[int, int, int]], - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() @@ -988,7 +975,7 @@ def _keye_field_config( class KeyeMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + data: dict[str, torch.Tensor] | ModalityData[ImageItem], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): return DictEmbeddingItems( @@ -1005,7 +992,7 @@ def _parse_image_data( def _parse_video_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + data: dict[str, torch.Tensor] | ModalityData[VideoItem], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): return DictEmbeddingItems( @@ -1033,7 +1020,7 @@ def get_image_processor(self, **kwargs: object): def get_supported_mm_limits( self, - ) -> Mapping[str, Optional[int]]: + ) -> Mapping[str, int | None]: return {"image": None, "video": None} def get_mm_max_tokens_per_item( @@ -1200,7 +1187,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1303,7 +1290,7 @@ class BaseKeyeModule(nn.Module): ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): @@ -1320,10 +1307,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.multimodal_config = multimodal_config + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.visual = KeyeSiglipVisionModel( config.vision_config, quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), + attn_backend_override=attn_backend_override, ) self.mlp_AR = self._build_projector( @@ -1348,7 +1341,7 @@ def _build_projector( self, text_config: PretrainedConfig, vision_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: raise ValueError("Need projector") @@ -1403,8 +1396,8 @@ def _process_video_embeds( self, video_type: Literal["video_embeds", "pixel_values_videos"], video_grid_thw: list[torch.Tensor], - pixel_values_videos: 
Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + pixel_values_videos: torch.Tensor | None = None, + ) -> torch.Tensor | list[torch.Tensor]: siglip_position_ids = list() video_grid_hws = list() sample_indices = list() @@ -1473,7 +1466,7 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings( self, **kwargs: object - ) -> Optional[MultiModalEmbeddings]: + ) -> MultiModalEmbeddings | None: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return None @@ -1483,22 +1476,22 @@ def get_multimodal_embeddings( for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Keye-VL. Args: @@ -1527,7 +1520,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1555,14 +1548,14 @@ def _build_projector( self, text_config: PretrainedConfig, vision_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: return Projector(text_config, vision_config, quant_config, prefix) def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[KeyeImageInputs]: + ) -> KeyeImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1586,7 +1579,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[KeyeVideoInputs]: + ) -> KeyeVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 578436fcad21..13e5b2d5f157 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -3,7 +3,7 @@ import itertools from collections.abc import Mapping, Sequence from functools import partial -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import numpy as np import torch @@ -38,7 +38,7 @@ ) from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .interfaces import SupportsLoRA, SupportsMRoPE, SupportsMultiModal, SupportsPP from .keye import ( 
BaseKeyeModule, BaseMultiModalProcessor, @@ -73,7 +73,7 @@ def split_thw(grid_thw: torch.Tensor) -> torch.Tensor: def get_num_patches( - grid_thw: torch.Tensor, num_frames: Union[list[int], torch.Tensor] + grid_thw: torch.Tensor, num_frames: list[int] | torch.Tensor ) -> list[int]: """ Return num_patches per video. @@ -153,7 +153,9 @@ class KeyeVL1_5ImageEmbeddingInputs(TensorSchema): image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] -KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs, KeyeVL1_5ImageEmbeddingInputs] +KeyeVL1_5ImageInputs: TypeAlias = ( + KeyeVL1_5ImagePixelInputs | KeyeVL1_5ImageEmbeddingInputs +) class KeyeVL1_5VideoPixelInputs(TensorSchema): @@ -191,7 +193,9 @@ class KeyeVL1_5VideoEmbeddingInputs(TensorSchema): num_frames: torch.Tensor -KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs, KeyeVL1_5VideoEmbeddingInputs] +KeyeVL1_5VideoInputs: TypeAlias = ( + KeyeVL1_5VideoPixelInputs | KeyeVL1_5VideoEmbeddingInputs +) class KeyeVL1_5Projector(nn.Module): @@ -199,7 +203,7 @@ def __init__( self, text_config: PretrainedConfig, vision_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -233,9 +237,9 @@ def __init__( def forward( self, - image_features: Union[torch.Tensor, tuple[torch.Tensor], list[torch.Tensor]], + image_features: torch.Tensor | tuple[torch.Tensor] | list[torch.Tensor], image_grid_thw: list[tuple[int, int, int]], - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() @@ -275,7 +279,7 @@ def get_max_frame_per_video(self) -> int: def get_supported_mm_limits( self, - ) -> Mapping[str, Optional[int]]: + ) -> Mapping[str, int | None]: return {"image": None, "video": 1} @@ -327,7 +331,7 @@ def _keye_field_config( class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + data: dict[str, torch.Tensor] | ModalityData[ImageItem], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): return DictEmbeddingItems( @@ -344,7 +348,7 @@ def _parse_image_data( def _parse_video_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + data: dict[str, torch.Tensor] | ModalityData[VideoItem], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): return DictEmbeddingItems( @@ -493,13 +497,13 @@ class KeyeVL1_5DummyInputsBuilder( dummy_inputs=KeyeVL1_5DummyInputsBuilder, ) class KeyeVL1_5ForConditionalGeneration( - BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP + BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): def _build_projector( self, text_config: PretrainedConfig, vision_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: return KeyeVL1_5Projector(text_config, vision_config, quant_config, prefix) @@ -511,7 +515,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[KeyeVL1_5ImageInputs]: + ) -> KeyeVL1_5ImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -535,7 +539,7 @@ def _parse_and_validate_image_input( def 
_parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[KeyeVL1_5VideoInputs]: + ) -> KeyeVL1_5VideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -589,3 +593,142 @@ def _process_video_input( end = patch_cu_seqlens[idx + 1] new_video_embeds.append(video_embeds[start:end]) return tuple(new_video_embeds) + + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + context_len: int = 0, + seq_len: int | None = None, + second_per_grid_ts: list[float] | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: + video_grid_thw = video_grid_thw[0] + """Get mrope input positions and delta value (Keye series).""" + + def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: + """ + Split grid_thw along the t dimension. + + Args: + grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. + + Returns: + List of [1, h, w] rows, repeated t times for each original row. + """ + + if isinstance(grid_thw, list): + grid_thw = torch.tensor(grid_thw, dtype=torch.long) + + if grid_thw.numel() == 0: + return [] + + t, hw = grid_thw[:, 0], grid_thw[:, 1:] + ones = torch.ones_like(hw[:, :1]) # [N,1] + out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) + return out.tolist() + + video_grid_thw = split_thw(video_grid_thw) + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + image_nums = len(image_grid_thw) + frame_nums = len(video_grid_thw) + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_frames = image_nums, frame_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + frame_nums): + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_frames > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_frames -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, 
w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index f7381e6b6b93..c2630fa6ac2b 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -46,7 +46,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal import torch from torch import nn @@ -153,7 +153,7 @@ class KimiVLImagePixelInputs(TensorSchema): type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("np", 3, "ps", "ps"), ] @@ -169,7 +169,7 @@ class KimiVLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(KimiVLConfig) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -227,7 +227,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -305,7 +305,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): supports_encoder_tp_data = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" @@ -370,7 +370,7 @@ def __init__( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[KimiVLImageInputs]: + ) -> KimiVLImageInputs | None: # image input type must be pixel values now pixel_values = kwargs.pop("pixel_values", None) image_grid_hws = kwargs.pop("image_grid_hws", None) @@ -411,7 +411,7 @@ def _process_image_input(self, image_input: KimiVLImageInputs) -> torch.Tensor: def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]: + def get_multimodal_embeddings(self, **kwargs: object) -> NestedTensors | None: # Validate the multimodal input keyword arguments image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: @@ -425,8 +425,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> IntermediateTensors: if intermediate_tensors is not None: @@ -570,7 +570,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def 
get_spec_layer_idx_from_weight_name( config: DeepseekV2Config, weight_name: str -) -> Optional[int]: +) -> int | None: if hasattr(config, "num_nextn_predict_layers") and ( config.num_nextn_predict_layers > 0 ): diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 425c93687760..5684b9a89125 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable from itertools import islice -from typing import Any, Optional +from typing import Any import torch import torch.nn as nn @@ -54,8 +54,8 @@ def __init__( ff_dim: int, multiple_of: int, auto_adjust_ff_dim: bool, - ffn_dim_multiplier: Optional[float], - quant_config: Optional[QuantizationConfig] = None, + ffn_dim_multiplier: float | None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -98,10 +98,10 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -190,9 +190,9 @@ def __init__( self, config: Lfm2Config, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -240,7 +240,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: @@ -258,9 +258,9 @@ def __init__( self, config: Lfm2Config, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -289,7 +289,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -365,8 +365,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -532,8 +532,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: hidden_states = self.model( diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index 728bd90be117..bb7926a9cfa9 100644 --- 
a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Any, Optional +from itertools import islice +from typing import Any import torch import torch.nn as nn @@ -64,7 +65,7 @@ def __init__( self, dim: int, ff_dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -95,7 +96,7 @@ class Lfm2MoeSparseMoeBlock(nn.Module): def __init__( self, config: Lfm2MoeConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", enable_eplb: bool = False, ): @@ -190,10 +191,10 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -282,9 +283,9 @@ def __init__( self, config: Lfm2MoeConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", enable_eplb: bool = False, ) -> None: @@ -339,7 +340,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: @@ -357,9 +358,9 @@ def __init__( self, config: Lfm2MoeConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", enable_eplb: bool = False, ) -> None: @@ -395,7 +396,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -478,8 +479,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -492,7 +493,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer : self.end_layer]: + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, @@ -773,8 +774,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: 
torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: hidden_states = self.model( diff --git a/vllm/model_executor/models/lightonocr.py b/vllm/model_executor/models/lightonocr.py new file mode 100644 index 000000000000..9839e4f8f707 --- /dev/null +++ b/vllm/model_executor/models/lightonocr.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, Mapping, Sequence +from typing import TypeVar + +import torch +import torch.nn as nn +from transformers import ( + BatchFeature, + PixtralVisionConfig, +) + +from vllm.config import VllmConfig +from vllm.model_executor.models.mistral3 import ( + Mistral3DummyInputsBuilder, + Mistral3ForConditionalGeneration, + Mistral3MultiModalProjector, + Mistral3ProcessingInfo, + _build_mistral3_info, + init_vision_tower_for_llava, +) +from vllm.model_executor.models.pixtral import PixtralHFEncoderInfo +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder + +_I = TypeVar("_I", bound=Mistral3ProcessingInfo) + + +class LightOnOCRMultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + # NOTE: LightOnOCR does not use break/end tokens, so we remove them here. + input_ids = processed_outputs.get("input_ids") + if input_ids is not None: + processor = self.info.get_hf_processor() + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + break_id = vocab.get(processor.image_break_token) + end_id = vocab.get(processor.image_end_token) + + # create mask to remove break/end tokens + keep_mask = ~torch.isin( + input_ids, + torch.tensor([break_id, end_id]), + ) + + processed_outputs["input_ids"] = input_ids[keep_mask].unsqueeze(0) + if "attention_mask" in processed_outputs: + processed_outputs["attention_mask"] = processed_outputs[ + "attention_mask" + ][keep_mask].unsqueeze(0) + + # un-pad pixel_values per-image so caches remain independent. 
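
The NOTE above strips Pixtral's image break/end tokens from the HF processor output before the un-padding step. A minimal, self-contained sketch of that masking idea on toy ids (7 and 8 stand in for `image_break_token`/`image_end_token`; they are placeholders, not the real vocabulary entries):

```python
import torch

# Toy id sequence with two "marker" tokens (7 and 8) to drop.
input_ids = torch.tensor([1, 5, 7, 5, 8, 2])
attention_mask = torch.ones_like(input_ids)

# Boolean mask that keeps everything except the marker ids.
keep_mask = ~torch.isin(input_ids, torch.tensor([7, 8]))

input_ids = input_ids[keep_mask].unsqueeze(0)            # shape [1, 4]
attention_mask = attention_mask[keep_mask].unsqueeze(0)  # stays aligned with ids
```
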
+ pixel_values = processed_outputs.get("pixel_values") + if pixel_values is not None: + image_sizes = processed_outputs["image_sizes"] + assert len(pixel_values) == len(image_sizes) + processed_outputs["pixel_values"] = [ + p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes) + ] + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + assert isinstance(hf_config.vision_config, PixtralVisionConfig) + encoder_info = PixtralHFEncoderInfo(hf_config) + + def replace(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + size = images.get_image_size(item_idx) + ncols, nrows = encoder_info.get_patch_grid_size( + image_width=size.width, image_height=size.height + ) + # break/end tokens are not used in LightOnOCR + tokens = [image_token_id] * (ncols * nrows) + return PromptUpdateDetails.select_token_id(tokens, image_token_id) + + return [ + PromptReplacement( + modality="image", target=[image_token_id], replacement=replace + ) + ] + + +def _build_LightOnOCR_processor( + info: _I, + dummy_inputs: BaseDummyInputsBuilder[_I], + *, + cache: BaseMultiModalProcessorCache | None = None, +): + assert isinstance(info, Mistral3ProcessingInfo) + return LightOnOCRMultiModalProcessor(info, dummy_inputs, cache=cache) + + +@MULTIMODAL_REGISTRY.register_processor( + _build_LightOnOCR_processor, + info=_build_mistral3_info, + dummy_inputs=Mistral3DummyInputsBuilder, +) +class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.vision_encoder.": "vision_tower.", + "model.vision_projection.": "multi_modal_projector.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_tower = init_vision_tower_for_llava( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower"), + ) + + self.multi_modal_projector = Mistral3MultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + spatial_merge_size=config.spatial_merge_size, + patch_size=config.vision_config.patch_size, + multimodal_projector_bias=config.multimodal_projector_bias, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def load_weights(self, weights: Iterable[tuple[str, 
torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 948c9280f953..7cc908e52c88 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -26,7 +26,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -76,7 +76,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, prefix: str = "", reduce_results: bool = True, @@ -121,12 +121,12 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, bias_o_proj: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -236,8 +236,8 @@ def forward( def _init_rotary_emb( self, config: LlamaConfig, - rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig], + rope_scaling: dict[str, Any] | None, + quant_config: QuantizationConfig | None, ) -> None: is_neox_style = True is_gguf = quant_config and quant_config.get_name() == "gguf" @@ -260,7 +260,7 @@ def __init__( self, vllm_config: VllmConfig, prefix: str = "", - config: Optional[LlamaConfig] = None, + config: LlamaConfig | None = None, ) -> None: super().__init__() @@ -331,7 +331,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -346,7 +346,7 @@ def forward( hidden_states = self.mlp(hidden_states) return hidden_states, residual - def get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]: + def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: """Get quantization config for this layer. 
Override in subclasses.""" return vllm_config.quant_config @@ -407,13 +407,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[ - torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] - ]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -627,9 +625,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -638,7 +636,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 075f35a098a4..33badb13fc9f 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -19,7 +19,7 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Any, Optional +from typing import Any import torch from torch import nn @@ -33,7 +33,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, @@ -42,7 +42,6 @@ ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -149,12 +148,12 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, bias_o_proj: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -297,7 +296,7 @@ def __init__( self, vllm_config: VllmConfig, prefix: str = "", - config: Optional[Llama4TextConfig] = None, + config: Llama4TextConfig | None = None, ) -> None: super().__init__() @@ -353,7 +352,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -399,7 +398,7 @@ def 
load_moe_expert_weights( params_dict: The dictionary of module parameters. loaded_params: The set of already loaded parameters. expert_params_mapping: The mapping of expert parameters. Must be - generated by FusedMoE.make_expert_params_mapping(). + generated by SharedFusedMoE.make_expert_params_mapping(). fused: Whether the expert weights are fused into a single weight tensor or are separate weight tensors for each expert. When fused is True, loaded_weight should have shape of: @@ -522,7 +521,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: fused_experts_params = False # Expert parameter mapping for the case where the expert weights are # not fused into a single weight tensor. - expert_params_mapping = FusedMoE.make_expert_params_mapping( + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -530,7 +529,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) # Expert parameter mapping for the case where the expert weights are # fused into a single weight tensor. - expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( + expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_up_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="gate_up_proj", diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 039022ef4527..90273463d64e 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -17,7 +17,6 @@ # limitations under the License. from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -49,7 +48,7 @@ def __init__( vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.config = vllm_config.speculative_config.draft_model_config.hf_config @@ -61,16 +60,23 @@ def __init__( prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList( - [ - Llama4DecoderLayer( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - config=self.config, - ) - for i in range(self.config.num_hidden_layers) - ] - ) + # Temporarily modify vllm_config.quant_config for draft model layers + original_quant_config = vllm_config.quant_config + vllm_config.quant_config = quant_config + try: + self.layers = nn.ModuleList( + [ + Llama4DecoderLayer( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, + ) + for i in range(self.config.num_hidden_layers) + ] + ) + finally: + # Restore original quant_config + vllm_config.quant_config = original_quant_config self.fc = torch.nn.Linear( self.config.hidden_size * 2, self.config.hidden_size, bias=False ) @@ -81,10 +87,10 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings(input_ids) @@ -136,7 +142,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: 
return loaded_params def validate_and_update_config( - self, start_layer_id: int, quant_config: Optional[QuantizationConfig] = None + self, start_layer_id: int, quant_config: QuantizationConfig | None = None ) -> None: # yoco and moe is not supported by draft model yet assert self.config.yoco_global_kv_layer is None @@ -193,7 +199,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states, inputs_embeds) diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 5df158818c9f..3617294bd621 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -13,6 +12,7 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM @@ -28,7 +28,7 @@ def __init__( vllm_config: VllmConfig, disable_input_layernorm: bool, prefix: str = "", - config: Optional[LlamaConfig] = None, + config: LlamaConfig | None = None, ) -> None: super().__init__(vllm_config, prefix=prefix, config=config) @@ -38,6 +38,17 @@ def __init__( del self.input_layernorm self.input_layernorm = nn.Identity() + def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: + """Use drafter's quantization config instead of verifier's.""" + draft_model_config = vllm_config.speculative_config.draft_model_config + draft_load_config = vllm_config.load_config + + return ( + VllmConfig.get_quantization_config(draft_model_config, draft_load_config) + if draft_model_config + else None + ) + @support_torch_compile class LlamaModel(nn.Module): @@ -155,7 +166,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if inputs_embeds is not None: raise NotImplementedError( diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 155a4ecea28f..da4bbda186b1 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -2,12 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn from transformers import LlamaConfig +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm @@ -21,6 +21,7 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM +from vllm.multimodal 
import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors from .utils import AutoWeightsLoader, maybe_prefix @@ -33,16 +34,21 @@ def __init__( self, vllm_config: VllmConfig, prefix: str = "", - config: Optional[LlamaConfig] = None, + config: LlamaConfig | None = None, + layer_idx: int = 0, ) -> None: super().__init__(vllm_config, prefix=prefix, config=config) config = config or vllm_config.model_config.hf_config quant_config = self.get_quant_config(vllm_config) + # First layer uses 2*hidden_size (embeds + hidden_states concatenated) + # Subsequent layers use hidden_size (only hidden_states, no embeds) + qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size + # override qkv self.self_attn.qkv_proj = QKVParallelLinear( - 2 * self.hidden_size, + qkv_input_size, self.self_attn.head_dim, self.self_attn.total_num_heads, self.self_attn.total_num_kv_heads, @@ -52,13 +58,14 @@ def __init__( ) self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx if getattr(config, "norm_before_residual", False): self._residual_norm = self._norm_before_residual else: self._residual_norm = self._norm_after_residual - def get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]: + def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None: """Use drafter's quantization config instead of verifier's.""" draft_model_config = vllm_config.speculative_config.draft_model_config draft_load_config = vllm_config.load_config @@ -88,13 +95,17 @@ def forward( positions: torch.Tensor, embeds: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: - embeds = self.input_layernorm(embeds) - - hidden_states, residual = self._residual_norm(hidden_states=hidden_states) + if self.layer_idx == 0: + # First layer: concatenate embeds with hidden_states + embeds = self.input_layernorm(embeds) + hidden_states, residual = self._residual_norm(hidden_states=hidden_states) + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + else: + # Subsequent layers: process hidden_states and residuals only + hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = torch.cat([embeds, hidden_states], dim=-1) # Self Attention hidden_states = self.self_attn( positions=positions, @@ -109,6 +120,15 @@ def forward( return hidden_states, residual +@support_torch_compile( + # torch.compile is disabled for multimodal EAGLE3 models due to constraint + # violations with dynamic shapes during tensor concatenation operations. + # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132 + # Non-multimodal EAGLE3 models can still use torch.compile safely. 
+ enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs( + vllm_config.model_config + ), +) class LlamaModel(nn.Module): def __init__( self, @@ -133,9 +153,11 @@ def __init__( [ LlamaDecoderLayer( current_vllm_config, - prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), + prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"), config=self.config, + layer_idx=layer_idx, ) + for layer_idx in range(self.config.num_hidden_layers) ] ) if hasattr(self.config, "target_hidden_size"): @@ -159,20 +181,20 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - input_embeds: Optional[torch.Tensor] = None, + input_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if input_embeds is None: input_embeds = self.get_input_embeddings(input_ids) assert hidden_states.shape[-1] == input_embeds.shape[-1] residual = None - hidden_states, residual = self.layers[0]( - positions, - input_embeds, - hidden_states, - residual, - ) - + for layer in self.layers: + hidden_states, residual = layer( + positions=positions, + embeds=input_embeds, + hidden_states=hidden_states, + residual=residual, + ) hidden_states, hidden_prenorm = self.norm(hidden_states, residual) return hidden_states, hidden_prenorm @@ -245,8 +267,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[NestedTensors] = None, - is_multimodal: Optional[torch.Tensor] = None, + multimodal_embeddings: NestedTensors | None = None, + is_multimodal: torch.Tensor | None = None, ) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -255,14 +277,14 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states, inputs_embeds) def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) if self.draft_id_to_target_id is None: assert logits.shape[1] == self.config.vocab_size, ( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3d46e22a0d21..a3dea0ce86f8 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union +from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar import torch import torch.nn as nn @@ -93,7 +93,7 @@ class PixtralHFImagePixelInputs(TensorSchema): type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" pixel_values: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}), ] @@ -110,9 +110,9 @@ class LlavaImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -LlavaImageInputs = Union[ - LlavaImagePixelInputs, PixtralHFImagePixelInputs, LlavaImageEmbeddingInputs -] +LlavaImageInputs: TypeAlias = ( + LlavaImagePixelInputs | PixtralHFImagePixelInputs | LlavaImageEmbeddingInputs +) class LlavaMultiModalProjector(nn.Module): @@ -122,7 +122,7 @@ def __init__( text_hidden_size: int, 
projector_hidden_act: str, multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -154,7 +154,7 @@ class LlavaLikeConfig(Protocol): vision_config: Final[PretrainedConfig] image_token_index: Final[int] vision_feature_select_strategy: Final[str] - vision_feature_layer: Final[Union[int, list[int]]] + vision_feature_layer: Final[int | list[int]] class LlavaLikeProcessor(Protocol): @@ -172,7 +172,7 @@ def get_vision_encoder_info(self): def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor: raise NotImplementedError - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -222,7 +222,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -406,7 +406,7 @@ def _build_llava_or_pixtral_hf_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> BaseMultiModalProcessor: if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( @@ -461,11 +461,11 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: def init_vision_tower_for_llava( hf_config: LlavaLikeConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, - require_post_norm: Optional[bool] = None, + require_post_norm: bool | None = None, prefix: str = "", -) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]: +) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel: vision_config = hf_config.vision_config # Initialize the vision tower only up to the deepest required feature layer @@ -524,7 +524,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -585,7 +585,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[LlavaImageInputs]: + ) -> LlavaImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -619,9 +619,9 @@ def _parse_and_validate_image_input( def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], - pixel_values: Union[torch.Tensor, list[torch.Tensor]], - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel, + pixel_values: torch.Tensor | list[torch.Tensor], + ) -> torch.Tensor | tuple[torch.Tensor, ...]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower return vision_tower( @@ -631,8 +631,8 @@ def _image_pixels_to_features( def _process_image_pixels( self, - inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs], - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs, + ) -> 
torch.Tensor | tuple[torch.Tensor, ...]: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] @@ -642,7 +642,7 @@ def _process_image_pixels( def _process_image_input( self, image_input: LlavaImageInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": return image_input["data"] @@ -672,10 +672,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for LLaVA-1.5. One key thing to understand is the `input_ids` already accounts for the @@ -725,7 +725,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -754,11 +754,11 @@ def get_hf_processor(self, **kwargs: object): class MantisMultiModalProcessor(LlavaMultiModalProcessor): def apply( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_uuids: Optional[MultiModalUUIDDict] = None, + tokenization_kwargs: Mapping[str, object] | None = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index caedace7cab1..3cf546644d04 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping -from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union +from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar import torch import torch.nn as nn @@ -55,11 +55,11 @@ class LlavaNextImagePixelInputs(TensorSchema): type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), ] - image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)] # This should be in `(height, width)` format. 
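
Most hunks in this region apply the same typing migration: `Optional[X]` becomes `X | None`, `Union[...]` becomes a PEP 604 union, and module-level unions such as `LlavaNextImageInputs` gain an explicit `TypeAlias` annotation. A minimal before/after sketch of the pattern (hypothetical `PixelInputs`/`EmbeddingInputs` names, assuming Python 3.10+ since the new syntax requires it):

```python
from typing import TypeAlias

import torch


class PixelInputs: ...


class EmbeddingInputs: ...


# Old spelling (removed throughout this patch):
#   ImageInputs = Union[PixelInputs, EmbeddingInputs]
#   def embed(x: Optional[torch.Tensor] = None) -> Union[torch.Tensor, list[torch.Tensor]]: ...

# New spelling (matching the added lines):
ImageInputs: TypeAlias = PixelInputs | EmbeddingInputs


def embed(x: torch.Tensor | None = None) -> torch.Tensor | list[torch.Tensor]:
    return [] if x is None else x
```
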
@@ -75,7 +75,9 @@ class LlavaNextImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, LlavaNextImageEmbeddingInputs] +LlavaNextImageInputs: TypeAlias = ( + LlavaNextImagePixelInputs | LlavaNextImageEmbeddingInputs +) class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): @@ -235,7 +237,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -294,7 +296,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[LlavaNextImageInputs]: + ) -> LlavaNextImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -324,7 +326,7 @@ def _parse_and_validate_image_input( def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + vision_tower: CLIPVisionModel | SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since @@ -424,7 +426,7 @@ def _merge_image_patch_embeddings( def _process_image_pixels( self, inputs: LlavaNextImagePixelInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] @@ -456,7 +458,7 @@ def _process_image_pixels( def _process_image_input( self, image_input: LlavaNextImageInputs, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: if image_input["type"] == "image_embeds": return [image_input["data"]] @@ -491,9 +493,9 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, # Multi-modal token ID may exceed vocab size handle_oov_mm_token: bool = True, ) -> torch.Tensor: @@ -512,10 +514,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for LlaVA-NeXT. 
One key thing to understand is the `input_ids` already accounts for the @@ -573,7 +575,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 074acc7943a4..77c331b0182b 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch import torch.nn as nn @@ -33,7 +33,7 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of +from vllm.utils.collection_utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -66,7 +66,7 @@ class LlavaNextVideoPixelInputs(TensorSchema): type: Literal["pixel_values_videos"] = "pixel_values_videos" pixel_values_videos: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), ] @@ -81,7 +81,7 @@ def get_vision_encoder_info(self): def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(LlavaNextVideoProcessor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"video": 1} def get_image_size_with_most_features(self) -> ImageSize: @@ -165,7 +165,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_videos = mm_counts.get("video", 0) @@ -313,7 +313,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" if modality.startswith("video"): @@ -356,7 +356,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[LlavaNextVideoPixelInputs]: + ) -> LlavaNextVideoPixelInputs | None: """ A legal video input should have the following dimensions: { @@ -381,7 +381,7 @@ def _parse_and_validate_video_input( def _video_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + vision_tower: CLIPVisionModel | SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since @@ -433,10 +433,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for LlaVA-NeXT-Video. 
Args: input_ids: Flattened (concatenated) input_ids corresponding to a @@ -455,7 +455,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 05f1621694c3..c4cae240ea46 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Final, Literal, Optional, Protocol, Union +from typing import Annotated, Final, Literal, Protocol, TypeAlias import torch import torch.nn as nn @@ -69,7 +69,7 @@ class LlavaOnevisionVideoPixelInputs(TensorSchema): type: Literal["pixel_values_videos"] = "pixel_values_videos" pixel_values_videos: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), ] @@ -90,11 +90,11 @@ class LlavaOnevisionImagePixelInputs(TensorSchema): type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), ] - image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)] class LlavaOnevisionImageEmbeddingInputs(TensorSchema): @@ -113,13 +113,13 @@ class LlavaOnevisionImageEmbeddingInputs(TensorSchema): ] -LlavaOnevisionImageInputs = Union[ - LlavaOnevisionImagePixelInputs, LlavaOnevisionImageEmbeddingInputs -] +LlavaOnevisionImageInputs: TypeAlias = ( + LlavaOnevisionImagePixelInputs | LlavaOnevisionImageEmbeddingInputs +) -LlavaOnevisionMultiInputs = Union[ - LlavaOnevisionImageInputs, LlavaOnevisionVideoPixelInputs -] +LlavaOnevisionMultiInputs: TypeAlias = ( + LlavaOnevisionImageInputs | LlavaOnevisionVideoPixelInputs +) class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): @@ -133,7 +133,7 @@ def get_hf_config(self) -> LlavaOnevisionLikeConfig: def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(LlavaOnevisionProcessor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": None} # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 @@ -276,7 +276,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -493,7 +493,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" if modality.startswith("video"): @@ -531,7 +531,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def _parse_and_validate_image_input( self, **kwargs: object - ) -> 
Optional[LlavaOnevisionImageInputs]: + ) -> LlavaOnevisionImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -560,7 +560,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[LlavaOnevisionVideoPixelInputs]: + ) -> LlavaOnevisionVideoPixelInputs | None: """ A legal video input should have the following dimensions: { @@ -606,7 +606,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + vision_tower: CLIPVisionModel | SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since @@ -726,7 +726,7 @@ def _merge_image_patch_embeddings( def _process_image_pixels( self, inputs: LlavaOnevisionImagePixelInputs, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] @@ -761,7 +761,7 @@ def _process_image_pixels( def _process_image_input( self, image_input: LlavaOnevisionImageInputs, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: if image_input["type"] == "image_embeds": return [image_input["data"]] @@ -788,7 +788,7 @@ def _process_image_input( def _video_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + vision_tower: CLIPVisionModel | SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since @@ -881,8 +881,8 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - vision_embeddings = self._process_image_input(multimodal_input) - multimodal_embeddings += tuple(vision_embeddings) + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_pixels(multimodal_input) multimodal_embeddings += tuple(video_embeddings) @@ -893,10 +893,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for LlaVA-Onevision. 
Args: input_ids: Flattened (concatenated) input_ids corresponding to a @@ -915,7 +915,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index 5020da37df89..5671347c00a2 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -35,7 +35,7 @@ import typing from collections.abc import Callable, Iterable -from typing import Optional, Union +from itertools import islice import torch from torch import nn @@ -114,7 +114,7 @@ def __init__( attention_dropout=0.0, mla_scale_q_lora=False, mla_scale_kv_lora=False, - torch_dtype="bfloat16", + dtype="bfloat16", params_dtype="bfloat16", router_dtype="float32", router_bias=False, @@ -130,7 +130,7 @@ def __init__( bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, - torch_dtype=torch_dtype, + dtype=dtype, params_dtype=params_dtype, router_dtype=router_dtype, topk_method=topk_method, @@ -193,7 +193,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ) -> None: @@ -269,8 +269,8 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", enable_eplb: bool = False, ): @@ -328,8 +328,8 @@ def __init__( self, vllm_config: VllmConfig, config: FlashConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", enable_eplb: bool = False, ) -> None: @@ -414,7 +414,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states @@ -505,9 +505,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -519,8 +519,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, @@ -592,9 +591,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + 
inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -603,7 +602,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/longcat_flash_mtp.py b/vllm/model_executor/models/longcat_flash_mtp.py index 55468f354c3a..e554d1e2de92 100644 --- a/vllm/model_executor/models/longcat_flash_mtp.py +++ b/vllm/model_executor/models/longcat_flash_mtp.py @@ -4,7 +4,6 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -35,7 +34,7 @@ def __init__( config: PretrainedConfig, prefix: str, vllm_config: VllmConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -55,7 +54,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_index: int = 0, ) -> torch.Tensor: assert inputs_embeds is not None @@ -78,7 +77,7 @@ def __init__( self, *, vllm_config: VllmConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -110,7 +109,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: if inputs_embeds is None: @@ -155,8 +154,8 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: hidden_states = self.model( @@ -168,7 +167,7 @@ def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -344,7 +343,7 @@ def _rewrite_spec_layer_name( def get_spec_layer_idx_from_weight_name( self, config: PretrainedConfig, weight_name: str - ) -> Optional[int]: + ) -> int | None: if "model.mtp" in weight_name: return config.num_hidden_layers * 2 return None diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index fa11f92cce33..fb145289fbfe 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -3,7 +3,7 @@ """PyTorch MAMBA model.""" from collections.abc import Iterable -from typing import Optional +from itertools import islice import torch from torch import nn @@ -48,10 +48,10 @@ class MambaDecoderLayer(nn.Module): def __init__( self, config: MambaConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False, + model_config: ModelConfig | None = 
None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + is_lora_enabled: bool | None = False, prefix: str = "", ) -> None: super().__init__() @@ -82,7 +82,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -148,8 +148,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -162,8 +162,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] + for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions=positions, hidden_states=hidden_states, residual=residual ) @@ -244,8 +243,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): hidden_states = self.backbone( diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index 4491648f3a0a..5eb21b966e18 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -3,7 +3,6 @@ """PyTorch MAMBA2 model.""" from collections.abc import Iterable -from typing import Optional import torch from torch import nn @@ -44,9 +43,9 @@ class Mamba2DecoderLayer(nn.Module): def __init__( self, config: MambaConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -76,7 +75,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -142,8 +141,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -277,8 +276,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): hidden_states = self.backbone( diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 47839a2c6b03..322cce79d4cb 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -25,8 +25,8 @@ import collections import collections.abc -from collections.abc import Iterable, Mapping, Sequence -from typing import 
Any, Callable, Optional, TypedDict, Union, cast +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Annotated, Any, TypeAlias, cast import numpy as np import torch @@ -62,11 +62,12 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.midashenglm import DashengConfig +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -_Tuple2 = Union[int, tuple[int, int], Sequence[int]] +_Tuple2: TypeAlias = int | tuple[int, int] | Sequence[int] def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]: @@ -105,7 +106,7 @@ def __init__( patch_stride: _Tuple2 = 16, in_chans: int = 1, embed_dim: int = 768, - norm_layer: Optional[Callable] = None, + norm_layer: Callable | None = None, flatten: bool = False, ): super().__init__() @@ -151,9 +152,9 @@ class DashengMlp(nn.Module): def __init__( self, in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, + hidden_features: int | None = None, + out_features: int | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -186,7 +187,7 @@ def __init__( dim: int, num_heads: int = 8, qkv_bias: bool = False, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -226,7 +227,7 @@ def __init__( prefix=f"{prefix}.proj", ) - def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None): B, N, C = x.shape qkv, _ = self.qkv(x) @@ -253,8 +254,8 @@ def __init__( num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, - init_values: Optional[float] = None, - quant_config: Optional[QuantizationConfig] = None, + init_values: float | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -285,7 +286,7 @@ def __init__( def forward( self, x: torch.Tensor, - mask: Optional[torch.Tensor] = None, + mask: torch.Tensor | None = None, ) -> torch.Tensor: x = x + self.ls1(self.attn(self.norm1(x), mask)) x = x + self.ls2(self.mlp(self.norm2(x))) @@ -349,7 +350,7 @@ class DashengAudioTransformer(nn.Module): def __init__( self, config: DashengConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -393,7 +394,7 @@ def __init__( def forward_features( self, x: torch.Tensor, - mask: Optional[torch.Tensor] = None, + mask: torch.Tensor | None = None, ) -> torch.Tensor: t = x.shape[-1] x = x + self.time_pos_embed[:, :, :, :t] @@ -418,8 +419,8 @@ def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor: def forward( self, x: torch.Tensor, - x_length: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + x_length: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: x = self.front_end(x) x = x.to(self.time_pos_embed.dtype) target_length_in_patches = self.target_length // 4 @@ -462,8 +463,8 @@ def __init__( in_dim: int, out_dim: int, downsample_rate=5, - dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, + dtype: torch.dtype | None = None, + 
quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -508,11 +509,16 @@ def forward(self, x, mask=None): # === Audio Inputs === # -class MiDashengLMAudioInputs(TypedDict): - input_values: torch.Tensor - """Shape: `(num_audios, num_sampling_points)`""" - audio_length: torch.Tensor - """Shape: `(num_audios, 1)`""" +class MiDashengLMAudioInputs(TensorSchema): + """ + + Dimensions: + - bn: Batch size * number of audios + - p: Number of sampling points + """ + + input_values: Annotated[torch.Tensor, TensorShape("n", "p")] + audio_length: Annotated[torch.Tensor, TensorShape("n")] class MiDashengLMProcessingInfo(BaseProcessingInfo): @@ -524,7 +530,7 @@ def get_feature_extractor(self): feature_extractor = hf_processor.feature_extractor return feature_extractor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": None} def get_min_audio_len(self): @@ -550,7 +556,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) @@ -676,6 +682,8 @@ def get_replacement_midashenglm(item_idx: int): dummy_inputs=MiDashengLMDummyInputsBuilder, ) class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -689,7 +697,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("audio"): return "<|audio_bos|><|AUDIO|><|audio_eos|>" @@ -728,44 +736,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.decoder.make_empty_intermediate_tensors ) - def _validate_and_reshape_mm_tensor( - self, mm_input: object, name: str - ) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - return mm_input.reshape(-1, *mm_input.shape[2:]) - - if name == "input_values": - max_length = max(tensor.shape[1] for tensor in mm_input) - padded_mm_input = [ - torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1])) - if tensor.shape[1] < max_length - else tensor - for tensor in mm_input - ] - return torch.concat(padded_mm_input) - - return torch.concat(mm_input) - def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[MiDashengLMAudioInputs]: + ) -> MiDashengLMAudioInputs | None: input_values = kwargs.pop("input_values", None) audio_length = kwargs.pop("audio_length", None) if input_values is None: return None - input_values = self._validate_and_reshape_mm_tensor( - input_values, "input_values" - ) - audio_length = self._validate_and_reshape_mm_tensor( - audio_length, "audio_length" - ) - if not isinstance(input_values, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of audio input features. 
" - f"Got type: {type(input_values)}" + + if isinstance(input_values, list): + input_values = torch.nn.utils.rnn.pad_sequence( + input_values, + batch_first=True, ) return MiDashengLMAudioInputs( @@ -773,7 +756,10 @@ def _parse_and_validate_audio_input( audio_length=audio_length, ) - def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor: + def _process_audio_input( + self, + audio_input: MiDashengLMAudioInputs, + ) -> tuple[torch.Tensor, ...]: # Process audio through encoder and projector input_values = audio_input["input_values"] audio_length = audio_input["audio_length"] @@ -783,17 +769,13 @@ def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Ten audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype) batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape - audio_length_np = ( - audio_length.cpu().numpy() - if isinstance(audio_length, torch.Tensor) - else audio_length - ) audio_output_lengths = [ max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame - for length in audio_length_np + for length in audio_length.tolist() ] - audio_output_lengths = torch.tensor(audio_output_lengths).to( - audio_embeddings.device + audio_output_lengths = torch.tensor( + audio_output_lengths, + device=audio_embeddings.device, ) audio_feature_mask = torch.arange( @@ -820,20 +802,12 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None - elif inputs_embeds is None: - multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) - inputs_embeds = self.get_input_embeddings( - input_ids, - multimodal_embeddings, - is_multimodal=input_ids == self.config.audio_token_id, - ) - input_ids = None return self.decoder.model( input_ids, @@ -845,7 +819,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.decoder.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index e01e06421842..726752a77e0d 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -28,7 +28,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch import torch.nn as nn @@ -64,9 +63,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -185,7 +184,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: hidden_states = self.model.norm(hidden_states) logits = self.logits_processor(self.lm_head, 
hidden_states) return logits diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index b678a06b7f20..3d7695a2a304 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -21,7 +21,6 @@ """Inference-only MiMo-MTP model.""" from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -48,8 +47,8 @@ def __init__( config: PretrainedConfig, prefix: str, model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() @@ -129,7 +128,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: if inputs_embeds is None: @@ -173,8 +172,8 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: assert spec_step_idx == 0, "mimo_mtp only support predict one token now" @@ -187,7 +186,7 @@ def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 06cb6bc61576..09328b472248 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -27,7 +27,7 @@ import math from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -89,8 +89,8 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, + params_dtype: torch.dtype | None = None, + tp_size: int | None = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -190,7 +190,7 @@ def __init__( intermediate_size: int, hidden_act: str, hidden_act_param: float, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -223,10 +223,10 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -305,8 +305,8 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, 
prefix: str = "", ) -> None: super().__init__() @@ -362,7 +362,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -425,8 +425,8 @@ def _init_layers( self, prefix: str, config: PretrainedConfig, - cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], + cache_config: CacheConfig | None, + quant_config: QuantizationConfig | None, ): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -444,11 +444,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[ - torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] - ]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -633,11 +631,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[ - torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] - ]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -658,7 +654,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 35f02a1538e8..ab4fe36476b9 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -25,7 +25,7 @@ # limitations under the License. 
"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" -from typing import Any, Optional +from typing import Any import torch from torch import nn @@ -63,10 +63,10 @@ def __init__( q_lora_rank: int, kv_lora_rank: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -214,8 +214,8 @@ def _init_layers( self, prefix: str, config: PretrainedConfig, - cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], + cache_config: CacheConfig | None, + quant_config: QuantizationConfig | None, ): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 6c635b248109..463af9bbe139 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -26,7 +26,6 @@ import math from collections.abc import Iterable -from typing import Optional, Union import torch from torch import nn @@ -61,8 +60,8 @@ class EagleMiniCPMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -118,7 +117,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -185,8 +184,8 @@ def _init_layers( self, prefix: str, config: PretrainedConfig, - cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], + cache_config: CacheConfig | None, + quant_config: QuantizationConfig | None, start_layer: int, ): self.eagle_layers = nn.ModuleList( @@ -210,7 +209,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: input_embeds = self.get_input_embeddings(input_ids) input_embeds = self.input_norm1(input_embeds) hidden_states = self.input_norm2(hidden_states) @@ -389,7 +388,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 34f05122abe3..fa2feb0ba10b 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -24,8 +24,8 @@ # limitations under the License. 
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Callable, Literal, Optional, Union +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal, TypeAlias import torch from torch import nn @@ -71,7 +71,7 @@ MiniCPMVProcessingInfo, _minicpmv_field_config, ) -from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix +from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix CPU_DEVICE = torch.device("cpu") @@ -89,7 +89,7 @@ class MiniCPMOAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] = "audio_features" audio_features: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bns", "c", "l", dynamic_dims={"l"}), ] """ @@ -99,7 +99,7 @@ class MiniCPMOAudioFeatureInputs(TensorSchema): """ audio_feature_lens: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "s"), ] """ @@ -121,24 +121,22 @@ class MiniCPMOAudioEmbeddingInputs(TensorSchema): type: Literal["audio_embeds"] = "audio_embeds" audio_embeds: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "s", "h", dynamic_dims={"s"}), ] -MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, MiniCPMOAudioEmbeddingInputs] +MiniCPMOAudioInputs: TypeAlias = ( + MiniCPMOAudioFeatureInputs | MiniCPMOAudioEmbeddingInputs +) def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): - audio_features = hf_inputs.get("audio_features", torch.empty(0)) - num_audios = len(audio_features) - return dict( **_minicpmv_field_config(hf_inputs), audio_features=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"), - audio_token_id=MultiModalFieldConfig.shared("audio", num_audios), ) @@ -162,8 +160,8 @@ def __init__( class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): def _parse_audio_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + data: dict[str, torch.Tensor] | ModalityData[AudioItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return MiniCPMOAudioEmbeddingItems( data, @@ -176,7 +174,7 @@ def _parse_audio_data( class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): audio_pattern = "(<audio>./</audio>)" - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {**super().get_supported_mm_limits(), "audio": None} def get_audio_placeholder( @@ -253,7 +251,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) audio_len = ( @@ -330,10 +328,6 @@ def process_audios( ] audio_inputs["audio_features"] = unpadded_audio_features - tokenizer = self.info.get_tokenizer() - unk_token_id = tokenizer.get_vocab()["<unk>"] - audio_inputs["audio_token_id"] = torch.tensor(unk_token_id) - return audio_inputs def process_mm_inputs( @@ -434,12 +428,10 @@ def forward( attention_mask: torch.Tensor, ) -> torch.Tensor: residual = hidden_states - past_key_values = None hidden_states = 
self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, past_key_values = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - past_key_value=past_key_values, ) hidden_states = nn.functional.dropout( hidden_states, p=self.dropout, training=self.training @@ -479,7 +471,7 @@ def __init__(self, config: WhisperConfig): def forward( self, input_features: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor | None = None, ) -> BaseModelOutputWithPast: # Ignore copy input_features = input_features.to( @@ -549,7 +541,7 @@ class MiniCPMO(MiniCPMV2_6): } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "(<image>./</image>)" if modality.startswith("video"): @@ -565,8 +557,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") ) - self.audio_token_id = None - def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): # Do not use parameters temporarily audio_config = self.config.audio_config @@ -722,50 +712,25 @@ def get_audio_hidden_states( def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[MiniCPMOAudioInputs]: + ) -> MiniCPMOAudioInputs | None: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) if audio_features is None and audio_embeds is None: return None - audio_token_id = kwargs.pop("audio_token_id") - if audio_token_id is not None: - assert isinstance(audio_token_id, torch.Tensor) - self.mm_token_ids.add(audio_token_id.flatten().unique().item()) - if audio_embeds is not None: - if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}" - ) - - audio_embeds_flat = flatten_bn(audio_embeds) - return MiniCPMOAudioEmbeddingInputs( type="audio_embeds", - audio_embeds=audio_embeds_flat, - ) - - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of audio_features. Got type: {type(audio_features)}" + audio_embeds=audio_embeds, ) audio_feature_lens = kwargs.pop("audio_feature_lens") - if not isinstance(audio_feature_lens, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of audio_feature_lens. 
" - f"Got type: {type(audio_feature_lens)}" - ) - - audio_features_flat = flatten_bn(audio_features) - audio_feature_lens_flat = flatten_bn(audio_feature_lens) return MiniCPMOAudioFeatureInputs( type="audio_features", - audio_features=audio_features_flat, - audio_feature_lens=audio_feature_lens_flat, + audio_features=audio_features, + audio_feature_lens=audio_feature_lens, ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: @@ -785,7 +750,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _process_audio_input( self, audio_input: MiniCPMOAudioInputs, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: if audio_input["type"] == "audio_embeds": return audio_input["audio_embeds"] @@ -797,7 +762,7 @@ def _process_multimodal_inputs(self, modalities: dict): for modality in modalities: if modality == "audios": audio_input = modalities["audios"] - audio_features = self._process_audio_input(audio_input) - multimodal_embeddings += tuple(audio_features) + audio_embeddings = self._process_audio_input(audio_input) + multimodal_embeddings += tuple(audio_embeddings) return multimodal_embeddings diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 09f973e98db9..09937706f8c5 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -26,10 +26,10 @@ import math from collections import defaultdict -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial from itertools import chain -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import numpy as np import torch @@ -42,14 +42,11 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig from vllm.model_executor.layers.resampler import ( BaseResampler, Resampler2, get_2d_sincos_pos_embed, ) -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -86,8 +83,9 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import flatten_2d_lists +from vllm.utils.collection_utils import flatten_2d_lists from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import ( @@ -114,7 +112,7 @@ class MiniCPMVImagePixelInputs(TensorSchema): type: Literal["pixel_values"] = "pixel_values" - # Note that the image size may vary, so we pass it as a list instead of a + # Note that the patch size may vary, so we pass it as a list instead of a # batched tensor. 
pixel_values: Annotated[ list[torch.Tensor], @@ -140,12 +138,12 @@ class MiniCPMVImageEmbeddingInputs(TensorSchema): type: Literal["image_embeds"] image_embeds: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "ns", "hs"), ] -MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs] +MiniCPMVImageInputs: TypeAlias = MiniCPMVImagePixelInputs | MiniCPMVImageEmbeddingInputs DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -156,10 +154,10 @@ def __init__( num_queries: int, embed_dim: int, num_heads: int, - kv_dim: Optional[int] = None, + kv_dim: int | None = None, norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, max_size: tuple[int, int] = (70, 70), - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__( @@ -251,11 +249,11 @@ def __init__( num_queries: int, embed_dim: int, num_heads: int, - kv_dim: Optional[int] = None, + kv_dim: int | None = None, norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, max_size: tuple[int, int] = (70, 70), max_temporal_size: int = 36000, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__( @@ -318,7 +316,7 @@ def _adjust_temporal_pos_cache( self.max_temporal_size = max_temporal_size self._set_temporal_pos_cache(self.max_temporal_size, device) - def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]): + def _init_weights(self, m: nn.Linear | nn.LayerNorm): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -453,12 +451,6 @@ def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]: def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): - pixel_values = hf_inputs.get("pixel_values", torch.empty(0)) - num_images = len(pixel_values) - - video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0)) - num_videos = len(video_pixel_values) - return dict( pixel_values=MultiModalFieldConfig.batched("image"), image_sizes=MultiModalFieldConfig.batched("image"), @@ -468,8 +460,6 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): video_image_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.batched("video"), - image_token_id=MultiModalFieldConfig.shared("image", num_images), - video_token_id=MultiModalFieldConfig.shared("video", num_videos), ) @@ -521,8 +511,8 @@ def get_num_frames(self, index: int) -> int: class MiniCPMVMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + data: dict[str, torch.Tensor] | ModalityData[ImageItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return MiniCPMVImageEmbeddingItems( data, @@ -533,8 +523,8 @@ def _parse_image_data( def _parse_video_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + data: dict[str, torch.Tensor] | ModalityData[VideoItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return MiniCPMVVideoEmbeddingItems( data, @@ -570,7 +560,7 @@ def get_image_processor(self, **kwargs: object): def get_model_version(self): return get_version_by_config(self.get_hf_config()) - def 
get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: mm_limits = {"image": None} if self.get_model_version() in {(2, 6), (4, 0), (4, 5)}: mm_limits["video"] = None @@ -582,7 +572,7 @@ def get_slice_image_placeholder( image_size: ImageSize, # For MiniCPM V/O 2.6 image_idx: int = 0, - max_slice_nums: Optional[int] = None, + max_slice_nums: int | None = None, use_image_id: bool = True, ) -> str: image_processor = self.get_image_processor() @@ -602,8 +592,8 @@ def get_sliced_grid( self, image_size: ImageSize, # For MiniCPM V/O 2.6 - max_slice_nums: Optional[int] = None, - ) -> Optional[tuple[int, int]]: + max_slice_nums: int | None = None, + ) -> tuple[int, int] | None: image_processor = self.get_image_processor() version = self.get_model_version() @@ -621,7 +611,7 @@ def get_sliced_grid( def get_num_image_tokens( self, image_size: ImageSize, - max_slice_nums: Optional[int] = None, + max_slice_nums: int | None = None, ) -> int: image_processor = self.get_image_processor() @@ -712,7 +702,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -792,10 +782,6 @@ def process_images( out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, ) - tokenizer = self.info.get_tokenizer() - unk_token_id = tokenizer.get_vocab()["<unk>"] - image_inputs["image_token_id"] = torch.tensor(unk_token_id) - return image_inputs def process_videos( @@ -831,10 +817,6 @@ def process_videos( video_inputs = {f"video_{k}": v for k, v in video_inputs.items()} - tokenizer = self.info.get_tokenizer() - unk_token_id = tokenizer.get_vocab()["<unk>"] - video_inputs["video_token_id"] = torch.tensor(unk_token_id) - return video_inputs def process_mm_inputs( @@ -1021,10 +1003,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): instantiated. """ + merge_by_field_config = True + supports_encoder_tp_data = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "(<image>./</image>)" if modality.startswith("video"): @@ -1066,57 +1050,30 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "resampler"), ) - self.mm_token_ids = set[int]() self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors def _parse_and_validate_vision_input( self, modality: str, **kwargs: object, - ) -> Optional[MiniCPMVImageInputs]: + ) -> MiniCPMVImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: return None - image_token_id = kwargs.pop("image_token_id") - if image_token_id is not None: - assert isinstance(image_token_id, torch.Tensor) - self.mm_token_ids.add(image_token_id.flatten().unique().item()) - if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of image_embeds for {modality=}. 
" - f"Got type: {type(image_embeds)}" - ) - - image_embeds_flat = flatten_bn(image_embeds) - return MiniCPMVImageEmbeddingInputs( type="image_embeds", - image_embeds=image_embeds_flat, - ) - - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of pixel_values for {modality=}. " - f"Got type: {type(pixel_values)}" + image_embeds=image_embeds, ) tgt_sizes = kwargs.pop("tgt_sizes") - if not isinstance(tgt_sizes, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of tgt_sizes for {modality=}. " - f"Got type: {type(tgt_sizes)}" - ) - - num_slices = [[len(p) for p in ps] for ps in pixel_values] - num_slices_flat = flatten_bn(torch.tensor(num_slices)) - pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values)) - tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True) + num_slices_flat = torch.tensor([len(ps) for ps in pixel_values]) + pixel_values_flat = flatten_bn(pixel_values) + tgt_sizes_flat = flatten_bn(tgt_sizes, concat=True) return MiniCPMVImagePixelInputs( type="pixel_values", @@ -1142,15 +1099,8 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: input_key in ("video_pixel_values", "video_embeds") and "videos" not in modalities ): - - def _image_key(video_key: str): - if video_key == "video_token_id": - return "image_token_id" - - return video_key.removeprefix("video_") - modalities["videos"] = self._parse_and_validate_vision_input( - "videos", **{_image_key(k): v for k, v in kwargs.items()} + "videos", **{k.removeprefix("video_"): v for k, v in kwargs.items()} ) return modalities @@ -1158,7 +1108,7 @@ def _image_key(video_key: str): def _process_vision_input( self, image_input: MiniCPMVImageInputs, - ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": return image_input["image_embeds"] @@ -1177,12 +1127,12 @@ def _process_multimodal_inputs(self, modalities: dict): for modality in modalities: if modality == "images": image_input = modalities["images"] - image_features = self._process_vision_input(image_input) - multimodal_embeddings += tuple(image_features) + image_embeddings = self._process_vision_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] - video_features = self._process_vision_input(video_input) - multimodal_embeddings += tuple(video_features) + video_embeddings = self._process_vision_input(video_input) + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -1200,8 +1150,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: Any, ) -> torch.Tensor: if intermediate_tensors is not None: @@ -1218,7 +1168,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.llm.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1243,7 +1193,7 @@ def init_llm( def init_vision_module( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", ) -> nn.Module: raise NotImplementedError @@ -1252,7 
+1202,7 @@ def init_resampler( self, embed_dim: int, vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: raise NotImplementedError @@ -1278,7 +1228,7 @@ def init_llm( def init_vision_module( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", ) -> nn.Module: # TODO: refactor vision model through timm wrapper from transformers @@ -1313,7 +1263,7 @@ def init_resampler( self, embed_dim: int, vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: with set_default_torch_dtype(torch.float16): @@ -1381,7 +1331,7 @@ def init_llm( def init_vision_module( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", ) -> nn.Module: model = Idefics2VisionTransformer( @@ -1398,7 +1348,7 @@ def init_resampler( self, embed_dim: int, vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: with set_default_torch_dtype(torch.float16): @@ -1474,7 +1424,7 @@ def init_llm( def init_vision_module( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: model = Idefics2VisionTransformer( @@ -1491,7 +1441,7 @@ def init_resampler( self, embed_dim: int, vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: with set_default_torch_dtype(torch.float16): @@ -1562,11 +1512,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (4, 0) - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): - return None - return quant_config - def init_llm( self, vllm_config: VllmConfig, @@ -1577,10 +1522,9 @@ def init_llm( def init_vision_module( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: - quant_config = self._maybe_ignore_quant_config(quant_config) model = Idefics2VisionTransformer( config.vision_config, quant_config=quant_config, @@ -1595,10 +1539,9 @@ def init_resampler( self, embed_dim: int, vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: - quant_config = self._maybe_ignore_quant_config(quant_config) with set_default_torch_dtype(torch.float16): # The resampler in 4.0 remains consistent with the one in 2.5/2.6. 
resampler = Resampler2_5( @@ -1667,11 +1610,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (4, 5) - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): - return None - return quant_config - def init_llm( self, vllm_config: VllmConfig, @@ -1682,10 +1620,9 @@ def init_llm( def init_vision_module( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: - quant_config = self._maybe_ignore_quant_config(quant_config) model = Idefics2VisionTransformer( config.vision_config, quant_config=quant_config, @@ -1700,10 +1637,9 @@ def init_resampler( self, embed_dim: int, vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: - quant_config = self._maybe_ignore_quant_config(quant_config) with set_default_torch_dtype(torch.float16): # The resampler in 4.0 remains consistent with the one in 2.5/2.6. resampler = Resampler4_5( diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index e6e0952f71dd..e262012dcd52 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -4,14 +4,13 @@ from collections.abc import Iterable from itertools import islice -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING if TYPE_CHECKING: pass import regex as re import torch -import torch.distributed from torch import nn from transformers import MiniMaxConfig @@ -83,7 +82,7 @@ def __init__( self, hidden_size: int, intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, layer_idx: int = None, prefix: str = "mlp", ) -> None: @@ -121,9 +120,9 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, + params_dtype: torch.dtype | None = None, layer_idx: int = None, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "moe", ) -> None: super().__init__() @@ -191,10 +190,10 @@ def __init__( rotary_dim: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - sliding_window: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, + sliding_window: int | None = None, + quant_config: QuantizationConfig | None = None, layer_idx: int = None, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "mha", ) -> None: super().__init__() @@ -273,12 +272,12 @@ class MiniMaxText01DecoderLayer(nn.Module): def __init__( self, config: MiniMaxConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, expert_num: int = 1, layer_id: int = None, - linear_layer_id: Optional[int] = None, + linear_layer_id: int | None = None, prefix: str = "decoder", ) -> None: self._ilayer = layer_id @@ -428,7 +427,7 @@ def forward( hidden_states: torch.Tensor, positions: torch.Tensor, attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], + 
residual: torch.Tensor | None, is_warmup: bool = False, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -627,12 +626,12 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -722,8 +721,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: hidden_states = self.model( diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index a25a7097a6ec..fb7c6d42a065 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal, TypeAlias import torch import torch.nn as nn @@ -52,11 +52,11 @@ class MiniMaxVL01ImagePixelInputs(TensorSchema): type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"}), ] - image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)] # This should be in `(height, width)` format. 
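Aside from the field-config and schema changes, the dominant edit across these model files is mechanical typing modernization: `typing.Optional`/`typing.Union` become PEP 604 unions, and module-level union aliases gain an explicit `TypeAlias` annotation. A minimal, self-contained sketch of that pattern (the names below are illustrative only and do not come from this patch):

from typing import TypeAlias

import torch

# Module-level union aliases are now annotated explicitly instead of using Union[...].
ImageInput: TypeAlias = torch.Tensor | list[torch.Tensor]

def to_batched(pixel_values: ImageInput, scale: float | None = None) -> torch.Tensor:
    # Accept either a single tensor or a list of tensors, mirroring the schemas above.
    x = torch.cat(pixel_values) if isinstance(pixel_values, list) else pixel_values
    return x if scale is None else x * scale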
@@ -72,9 +72,9 @@ class MiniMaxVL01ImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -MiniMaxVL01ImageInputs = Union[ - MiniMaxVL01ImagePixelInputs, MiniMaxVL01ImageEmbeddingInputs -] +MiniMaxVL01ImageInputs: TypeAlias = ( + MiniMaxVL01ImagePixelInputs | MiniMaxVL01ImageEmbeddingInputs +) class MiniMaxVL01MultiModalProjector(nn.Module): @@ -84,7 +84,7 @@ def __init__( text_hidden_size: int, projector_hidden_act: str, multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -127,7 +127,7 @@ def get_hf_processor(self, **kwargs: object): return hf_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} @@ -187,7 +187,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -239,9 +239,9 @@ def get_language_model(self) -> torch.nn.Module: def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], - pixel_values: Union[torch.Tensor, list[torch.Tensor]], - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel, + pixel_values: torch.Tensor | list[torch.Tensor], + ) -> torch.Tensor | tuple[torch.Tensor, ...]: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower feature_select_strategy = self.config.vision_feature_select_strategy @@ -302,7 +302,7 @@ def pack_image_features( def _process_image_pixels( self, inputs: MiniMaxVL01ImagePixelInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] @@ -311,7 +311,7 @@ def _process_image_pixels( def _process_image_input( self, image_input: MiniMaxVL01ImageInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": return image_input["data"] @@ -330,7 +330,7 @@ def _process_image_input( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[MiniMaxVL01ImageInputs]: + ) -> MiniMaxVL01ImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -364,10 +364,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: @@ -388,7 +388,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git 
a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 8e74425c5dbd..26d4deca2e12 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union +from typing import Annotated, Final, Literal, Protocol, TypeVar import torch import torch.nn as nn @@ -72,7 +72,7 @@ class Mistral3ImagePixelInputs(TensorSchema): # Note that `height` or `width` may be different per batch and image, # in which case the data is passed as a list instead of a batched tensor. pixel_values: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), ] @@ -136,7 +136,7 @@ def __init__( patch_size: int, projector_hidden_act: str, multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -179,7 +179,7 @@ class LlavaLikeConfig(Protocol): vision_config: Final[PretrainedConfig] image_token_index: Final[int] vision_feature_select_strategy: Final[str] - vision_feature_layer: Final[Union[int, list[int]]] + vision_feature_layer: Final[int | list[int]] class LlavaLikeProcessor(Protocol): @@ -197,7 +197,7 @@ def get_vision_encoder_info(self): def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor: raise NotImplementedError - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -234,7 +234,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -348,7 +348,7 @@ def _build_mistral3_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> BaseMultiModalProcessor: assert isinstance(info, Mistral3ProcessingInfo) return Mistral3MultiModalProcessor( @@ -394,9 +394,9 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: def init_vision_tower_for_llava( hf_config: LlavaLikeConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, - require_post_norm: Optional[bool] = None, + require_post_norm: bool | None = None, prefix: str = "", ) -> PixtralHFVisionModel: vision_config = hf_config.vision_config @@ -441,7 +441,7 @@ class Mistral3ForConditionalGeneration( ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None @@ -504,7 +504,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Mistral3ImagePixelInputs]: + ) -> Mistral3ImagePixelInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -519,7 +519,7 @@ def _parse_and_validate_image_input( def _process_image_input( self, image_input: Mistral3ImagePixelInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, 
...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": return image_input["data"] @@ -562,10 +562,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Mistral3. One key thing to understand is the `input_ids` already accounts for the @@ -615,7 +615,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 37b49349ec12..bc56481820a9 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -27,7 +27,6 @@ import typing from collections.abc import Callable, Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -88,10 +87,10 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - dp_size: Optional[int] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + tp_size: int | None = None, + dp_size: int | None = None, prefix: str = "", enable_eplb: bool = False, ): @@ -163,8 +162,8 @@ def __init__( num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -242,8 +241,8 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", enable_eplb: bool = False, ) -> None: @@ -280,7 +279,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual is None: @@ -353,9 +352,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -615,9 +614,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | 
IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -626,7 +625,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index b624a6200ab3..81be1135dfd9 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -19,7 +19,7 @@ import math from collections.abc import Iterable, Mapping from itertools import tee -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch from torch import nn @@ -71,7 +71,7 @@ SupportsPP, ) from .llama4 import Llama4ForCausalLM -from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix +from .utils import AutoWeightsLoader, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -86,7 +86,7 @@ class Llama4ImagePatchInputs(TensorSchema): type: Literal["pixel_values"] = "pixel_values" - flat_data: Annotated[ + pixel_values: Annotated[ torch.Tensor, TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"), ] @@ -96,7 +96,7 @@ class Llama4ImagePatchInputs(TensorSchema): The number of total patches for each image in the batch. This is used to split the embeddings which has the first two dimensions - flattened just like `flat_data`. + flattened just like `pixel_values`. """ aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)] @@ -115,7 +115,7 @@ def __init__( output_size: int, bias: bool, output_activation: bool, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -152,7 +152,7 @@ class Llama4MultiModalProjector(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -199,7 +199,7 @@ class Llama4VisionPixelShuffleMLP(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -229,7 +229,7 @@ class Llama4VisionAttention(nn.Module): def __init__( self, config: Llama4VisionConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", use_data_parallel: bool = False, ): @@ -323,7 +323,7 @@ class Llama4VisionEncoderLayer(nn.Module): def __init__( self, config: Llama4VisionConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", use_data_parallel: bool = False, ): @@ -376,7 +376,7 @@ class Llama4VisionEncoder(nn.Module): def __init__( self, config: Llama4VisionConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", use_data_parallel: bool = False, ): @@ -419,7 +419,7 @@ class Llama4UnfoldConvolution(nn.Module): def __init__( self, config: Llama4VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -449,7 +449,7 @@ class Llama4VisionModel(nn.Module): def __init__( self, config: Llama4VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | 
None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -547,7 +547,7 @@ def get_hf_processor(self, **kwargs: object) -> Llama4Processor: Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs ) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: # Although vLLM can support more images from an infra capability # perspective, we do not recommend using >10 images in practice. return {"image": None} @@ -699,7 +699,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -725,6 +725,8 @@ def get_dummy_mm_data( class Llama4ForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 ): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -733,7 +735,7 @@ class Llama4ForConditionalGeneration( supports_encoder_tp_data = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|image|>" @@ -792,23 +794,18 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Llama4ImagePatchInputs]: + ) -> Llama4ImagePatchInputs | None: # num_images, 1, num_chunks, channel, image_size, image_size pixel_values = kwargs.pop("pixel_values", None) if pixel_values is None: return None - # num_images x num_chunks, channel, image_size, image_size - # TODO: confirm handling for variable lengths - flat_pixel_values = flatten_bn(pixel_values, concat=True) - patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) + patches_per_image = kwargs.pop("patches_per_image") aspect_ratios = kwargs.pop("aspect_ratios") - if aspect_ratios.ndim == 3: - aspect_ratios = aspect_ratios.squeeze(1) return Llama4ImagePatchInputs( type="pixel_values", - flat_data=flat_pixel_values, + pixel_values=pixel_values, patches_per_image=patches_per_image, aspect_ratios=aspect_ratios, ) @@ -817,16 +814,16 @@ def _process_image_input( self, image_input: Llama4ImagePatchInputs ) -> MultiModalEmbeddings: assert self.vision_model and self.multi_modal_projector - flat_data = image_input["flat_data"] + pixel_values = image_input["pixel_values"] patches_per_image = image_input["patches_per_image"].tolist() # shard image input if self.use_data_parallel: vision_embeddings_flat = run_dp_sharded_vision_model( - flat_data, self.vision_model + pixel_values, self.vision_model ) else: - vision_embeddings_flat = self.vision_model(flat_data) + vision_embeddings_flat = self.vision_model(pixel_values) vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat) @@ -849,10 +846,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -863,7 +860,7 @@ def forward( def compute_logits( self, 
hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def separate_weights( diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 58e2acb8ce92..5a0769f3bdaa 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Set -from typing import Optional, Union import torch from torch import nn @@ -40,9 +39,12 @@ def __init__(self, config: ModernBertConfig): self.tok_embeddings = VocabParallelEmbedding( config.vocab_size, config.hidden_size ) - self.norm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps, bias=config.norm_bias + eps = ( + getattr(config, "norm_eps", None) + or getattr(config, "layer_norm_eps", None) + or 1e-5 ) + self.norm = nn.LayerNorm(config.hidden_size, eps=eps, bias=config.norm_bias) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) @@ -50,7 +52,7 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, input_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if inputs_embeds is not None: return self.norm(inputs_embeds) @@ -74,7 +76,7 @@ def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: flo class ModernBertAttention(nn.Module): - def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + def __init__(self, config: ModernBertConfig, layer_id: int | None = None): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -151,7 +153,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class ModernBertLayer(nn.Module): def __init__( - self, config: ModernBertConfig, prefix: str = "", layer_id: Optional[int] = None + self, config: ModernBertConfig, prefix: str = "", layer_id: int | None = None ): super().__init__() self.config = config @@ -243,8 +245,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -287,9 +289,9 @@ def _head(self, pooled_output: torch.Tensor): def forward( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], + hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: pooled_output = self.pooling(hidden_states, pooling_metadata) if isinstance(pooled_output, list): @@ -323,20 +325,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( - pooling=self.pooling, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + pooling=self.pooling, classifier=self.classifier, act_fn="classify" ), "score": ClassifierPooler( - pooling=self.pooling, - 
classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=self.pooling, classifier=self.classifier, act_fn="score" ), } ) @@ -370,10 +366,10 @@ def weight_filter(): def forward( self, - input_ids: Optional[torch.LongTensor], + input_ids: torch.LongTensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: return self.model( input_ids=input_ids, @@ -422,7 +418,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config=pooler_config + ), } ) @@ -436,10 +434,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 666796d835a3..9e7d997bdb01 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -5,7 +5,6 @@ # https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py from dataclasses import dataclass, field -from typing import Union @dataclass @@ -55,10 +54,10 @@ class MultiModelKeys(ModelKeys): @staticmethod def from_string_field( - language_model: Union[str, list[str]] = None, - connector: Union[str, list[str]] = None, - tower_model: Union[str, list[str]] = None, - generator: Union[str, list[str]] = None, + language_model: str | list[str] = None, + connector: str | list[str] = None, + tower_model: str | list[str] = None, + generator: str | list[str] = None, **kwargs, ) -> "MultiModelKeys": def to_list(value): diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 734841d0dc98..dce94d181c4c 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from functools import cached_property, partial from itertools import islice -from typing import Annotated, Optional, Union +from typing import Annotated import numpy as np import torch @@ -75,7 +75,6 @@ from .utils import ( AutoWeightsLoader, WeightsMapper, - flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -97,28 +96,19 @@ class MolmoImageInputs(TensorSchema): """ Dimensions: - bn: Batch size * number of images - - nc: Number of crops (dynamic) + - bnc: Batch size * number of images * number of crops (dynamic) - np: Number of patches - tp: Token sequence positions - pd: Patch dimension """ - images: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}), - ] - # Number of crops may vary per batch and image, so pass it as a list. 
+ images: Annotated[torch.Tensor, TensorShape("bnc", "np", "pd")] - image_masks: Annotated[ - Optional[Union[torch.Tensor, list[torch.Tensor]]], - TensorShape("bn", "nc", "np", dynamic_dims={"nc"}), - ] + image_masks: Annotated[torch.Tensor | None, TensorShape("bnc", "np")] + + image_input_idx: Annotated[torch.Tensor, TensorShape("bnc", "tp")] + """An index tensor that maps image features to their corresponding patch tokens.""" - feat_is_patch: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}), - ] - # A boolean mask indicating which image features correspond to patch tokens. num_crops: Annotated[torch.Tensor, TensorShape("bn")] @@ -151,7 +141,7 @@ class ViTMLP(nn.Module): def __init__( self, config: VisionBackboneConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.w1 = ColumnParallelLinear( @@ -185,7 +175,7 @@ def __init__( config: VisionBackboneConfig, use_bias: bool = True, nlayers: int = 1, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() @@ -238,7 +228,7 @@ def __init__( ) def forward( - self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None + self, inputs_q: torch.Tensor, inputs_kv: torch.Tensor | None = None ) -> torch.Tensor: if inputs_kv is not None: inputs_k = inputs_kv @@ -263,7 +253,7 @@ class ResidualAttentionBlock(nn.Module): def __init__( self, config: VisionBackboneConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config) @@ -289,7 +279,7 @@ class BlockCollection(nn.Module): def __init__( self, config: VisionBackboneConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.resblocks = nn.ModuleList( @@ -317,7 +307,7 @@ class VisionTransformer(nn.Module): def __init__( self, config: VisionBackboneConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() scale = config.image_emb_dim**-0.5 @@ -367,7 +357,7 @@ def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: return x def forward( - self, x: torch.Tensor, patch_num: Optional[int] = None + self, x: torch.Tensor, patch_num: int | None = None ) -> list[torch.Tensor]: """ : param x: (batch_size, num_patch, n_pixels) @@ -396,8 +386,8 @@ class MolmoAttention(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -432,9 +422,9 @@ def __init__( quant_config=quant_config, ) - self.tp_rank: Optional[int] = None - self.k_norm: Optional[nn.Module] = None - self.q_norm: Optional[nn.Module] = None + self.tp_rank: int | None = None + self.k_norm: nn.Module | None = None + self.q_norm: nn.Module | None = None if config.attention_layer_norm: self.tp_rank = get_tensor_model_parallel_rank() self.k_norm = RMSNorm( @@ -503,8 +493,8 @@ class LanguageModelMLP(nn.Module): def __init__( self, config: PretrainedConfig, - input_dim: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, + input_dim: int | None = None, + quant_config: 
QuantizationConfig | None = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -542,8 +532,8 @@ class ImageProjectorMLP(nn.Module): def __init__( self, config: PretrainedConfig, - input_dim: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, + input_dim: int | None = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -580,8 +570,8 @@ class MolmoDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -604,8 +594,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]: + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: # Self Attention if residual is None: residual = hidden_states @@ -627,8 +617,8 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]: + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: # Self Attention residual = hidden_states hidden_states = self.self_attn( @@ -654,7 +644,7 @@ def __init__( self, config: PretrainedConfig, vision_config: VisionBackboneConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.vit_layers = VIT_LAYERS @@ -849,8 +839,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -1064,7 +1054,7 @@ def image_token_length_h(self) -> int: return image_token_length_h @property - def message_format(self) -> Optional[str]: + def message_format(self) -> str | None: return "role" @property @@ -1145,9 +1135,9 @@ def get_patches_grid_size( def __call__( self, - text: Optional[Union[TextInput, list[TextInput]]] = None, - images: Optional[Union[ImageInput, list[ImageInput]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: TextInput | list[TextInput] | None = None, + images: ImageInput | list[ImageInput] | None = None, + return_tensors: str | TensorType | None = None, **kwargs, ) -> BatchFeature: outputs = self.processor.process( # type: ignore @@ -1177,7 +1167,7 @@ def __call__( num_crops = torch.tensor(tilings).prod(-1) + 1 assert num_crops.sum() == len(feat_is_patch) - outputs["feat_is_patch"] = feat_is_patch + outputs["image_input_idx"] = image_input_idx outputs["num_crops"] = num_crops outputs["img_patch_id"] = self.image_patch_id @@ -1189,7 +1179,7 @@ def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper: processor = self.ctx.get_hf_processor(**kwargs) return MolmoProcessorWrapper(processor) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -1197,7 
+1187,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[MolmoProcessorWrapper], + processor: MolmoProcessorWrapper | None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -1211,8 +1201,9 @@ def get_num_image_tokens( image_token_length_w = processor.image_token_length_w image_token_length_h = processor.image_token_length_h - extra = image_token_length_w * image_token_length_h - joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size) + # Calculate total tokens: 2 for start/end + (w+1)*h for column separators + extra = 2 + (image_token_length_w + 1) * image_token_length_h + joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size) return extra + joint @@ -1249,7 +1240,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) @@ -1273,13 +1264,16 @@ def _apply_hf_processor_tokens_only( ) -> list[int]: processor = self.info.get_hf_processor() - # Apply the chat template to the tokens + # The chat template is already applied to the prompt tokens + # Use message_format="none" to avoid applying it again + # Prepend an empty space if `always_start_with_space` is True tokens = processor.processor.get_tokens_input( # type: ignore self.info.get_tokenizer().decode(prompt_tokens), - message_format=processor.message_format, + message_format="none", always_start_with_space=processor.always_start_with_space, ) + # Prepend a BOS token id to the tokens processed_data = self.info.ctx.call_hf_processor( processor, # type: ignore dict(tokens=tokens), @@ -1299,7 +1293,7 @@ def _get_mm_fields_config( return dict( images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops), - feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops), + image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops), num_crops=MultiModalFieldConfig.batched("image"), img_patch_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -1362,6 +1356,8 @@ def get_insertion_molmo(item_idx: int): class MolmoForCausalLM( nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant ): + merge_by_field_config = True + hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ # vision backbone mapping @@ -1397,7 +1393,7 @@ class MolmoForCausalLM( } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None @@ -1441,32 +1437,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object, - ) -> Optional[MolmoImageInputs]: + ) -> MolmoImageInputs | None: images = kwargs.pop("images", None) image_masks = kwargs.pop("image_masks", None) - feat_is_patch = kwargs.pop("feat_is_patch", None) + image_input_idx = kwargs.pop("image_input_idx", None) num_crops = kwargs.pop("num_crops", None) if images is None: return None - if not isinstance(num_crops, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of num_crops. 
Got type: {type(num_crops)}" - ) - num_crops = flatten_bn(num_crops, concat=True) - img_patch_id = kwargs.pop("img_patch_id", None) - if not isinstance(img_patch_id, torch.Tensor): - raise ValueError( - f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}" - ) - self.img_patch_id = img_patch_id.flatten().unique().item() + if isinstance(img_patch_id, torch.Tensor): + img_patch_id = img_patch_id.item() + + assert isinstance(img_patch_id, int) + self.img_patch_id = img_patch_id return MolmoImageInputs( images=images, image_masks=image_masks, - feat_is_patch=feat_is_patch, + image_input_idx=image_input_idx, num_crops=num_crops, ) @@ -1476,31 +1466,28 @@ def _process_image_input( ) -> list[torch.Tensor]: images = image_input["images"] image_masks = image_input["image_masks"] - feat_is_patch = image_input["feat_is_patch"] + image_input_idx = image_input["image_input_idx"] num_crops = image_input["num_crops"] # Call the vision backbone on the whole batch at once - images_flat = flatten_bn(images, concat=True) - image_masks_flat = ( - None if image_masks is None else flatten_bn(image_masks, concat=True) - ) - feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True) - - image_features_flat = self.vision_backbone( - images=images_flat.unsqueeze(0), - image_masks=( - None if image_masks_flat is None else image_masks_flat.unsqueeze(0) - ), + image_features = self.vision_backbone( + images=images.unsqueeze(0), + image_masks=None if image_masks is None else image_masks.unsqueeze(0), ).squeeze(0) # Only the features corresponding to patch tokens are relevant - return [ - feats[f_is_patch] - for feats, f_is_patch in zip( - image_features_flat.split(num_crops.tolist()), - feat_is_patch_flat.split(num_crops.tolist()), - ) - ] + # Re-order the features using the image_input_idx tensor + results = [] + num_crops_list = num_crops.tolist() + for feats, img_idx in zip( + image_features.split(num_crops_list), + image_input_idx.split(num_crops_list), + ): + is_valid = img_idx >= 0 + valid_img_idx = img_idx[is_valid] + order = torch.argsort(valid_img_idx) + results.append(feats[is_valid][order]) + return results def get_language_model(self) -> torch.nn.Module: return self.model @@ -1516,8 +1503,8 @@ def forward( self, input_ids: torch.LongTensor, positions: torch.LongTensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index 3bf8fce0de0d..96ec6e6b56ac 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -45,7 +45,6 @@ from collections.abc import Sequence from copy import deepcopy from functools import cached_property -from typing import Optional, Union import torch import torch.nn as nn @@ -68,8 +67,8 @@ def multihead_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - q_cu_seqlens: Optional[torch.Tensor] = None, - k_cu_seqlens: Optional[torch.Tensor] = None, + q_cu_seqlens: torch.Tensor | None = None, + k_cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: """Multi-head attention using flash attention 2. 
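For context on the moonvit attention helpers above, which take packed q/k/v tensors plus cu_seqlens boundaries, here is a minimal, self-contained sketch of variable-length attention built on torch's scaled_dot_product_attention. It only illustrates the cu_seqlens convention, is not the vLLM/moonvit implementation (which also has a flash-attention path), and assumes query and key lengths match per sequence:

import torch
import torch.nn.functional as F

def packed_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                cu_seqlens: torch.Tensor) -> torch.Tensor:
    # q, k, v: (total_tokens, num_heads, head_dim); cu_seqlens: (num_seqs + 1,)
    outputs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # SDPA expects (batch, heads, seq, head_dim), so process one sequence at a time.
        q_i = q[start:end].transpose(0, 1).unsqueeze(0)
        k_i = k[start:end].transpose(0, 1).unsqueeze(0)
        v_i = v[start:end].transpose(0, 1).unsqueeze(0)
        o_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
        outputs.append(o_i.squeeze(0).transpose(0, 1))
    # Back to packed layout: (total_tokens, num_heads, head_dim)
    return torch.cat(outputs, dim=0)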
@@ -121,8 +120,8 @@ def sdpa_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - q_cu_seqlens: Optional[torch.Tensor] = None, - k_cu_seqlens: Optional[torch.Tensor] = None, + q_cu_seqlens: torch.Tensor | None = None, + k_cu_seqlens: torch.Tensor | None = None, ) -> torch.Tensor: """SDPA attention. @@ -230,7 +229,7 @@ def __init__( self, out_dim: int, in_dim: int = 3, - patch_size: Union[int, tuple[int, int]] = (14, 14), + patch_size: int | tuple[int, int] = (14, 14), pos_emb_height: int = 14, pos_emb_width: int = 14, ): @@ -460,7 +459,7 @@ def attention_qkvpacked( self, x: torch.Tensor, cu_seqlens: torch.Tensor, - rope_freqs_cis: Optional[torch.Tensor] = None, + rope_freqs_cis: torch.Tensor | None = None, ): """ Args: @@ -491,7 +490,7 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - rope_freqs_cis: Union[torch.Tensor, None] = None, + rope_freqs_cis: torch.Tensor | None = None, ) -> torch.Tensor: """ Args: diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 3f1f2bbcb026..936dbf6c3243 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -5,7 +5,6 @@ import math from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch import torch.nn as nn @@ -58,8 +57,8 @@ class MPTAttention(nn.Module): def __init__( self, config: MptConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -152,7 +151,7 @@ class MPTMLP(nn.Module): def __init__( self, config: MptConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() hidden_size = config.d_model @@ -183,8 +182,8 @@ class MPTBlock(nn.Module): def __init__( self, config: MptConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -251,9 +250,9 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -311,9 +310,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -322,7 +321,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 91dfa6735534..86fc1d6046ce 100644 --- 
a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -11,9 +11,10 @@ import warnings from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, Optional, TypedDict, TypeVar, Union +from typing import Annotated, Any, Literal, TypeAlias, TypeVar import numpy.typing as npt +import regex as re import torch import torch.nn as nn import torchvision.transforms as T @@ -21,7 +22,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -40,7 +41,6 @@ from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.utils import ( - flatten_bn, init_vllm_registered_model, maybe_prefix, ) @@ -54,12 +54,14 @@ MultiModalFieldConfig, MultiModalKwargs, MultiModalKwargsItems, + VideoItem, ) from vllm.multimodal.parse import ( ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems, + MultiModalDataParser, ) from vllm.multimodal.processing import ( BaseMultiModalProcessor, @@ -92,43 +94,48 @@ IMG_CONTEXT = "<image>" # Profiling -MAX_FRAMES = 16 +# MAX_FRAMES = 16 DEFAULT_NUM_TILES = 12 -class NanoNemotronVLImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_flat: torch.Tensor +class NanoNemotronVLImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - bn: Batch size * number of images + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height of each image patch + - w: Width of each image patch """ - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + type: Literal["pixel_values"] + pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] -class NanoNemotronVLImageEmbeddinInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - - `hidden_size` must match the hidden size of language model backbone. 
+class NanoNemotronVLImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - n: Number of images + - f: Total image feature size + - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] + data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")] -NanoNemotronVLImageInputs = Union[ - NanoNemotronVLImagePixelInputs, NanoNemotronVLImageEmbeddinInputs -] + +NanoNemotronVLImageInputs: TypeAlias = ( + NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs +) class NanoNemotronVLVideoPixelInputs(TensorSchema): """ Dimensions: - bvf: Batch size * number of videos * num_frames - - bn: Batch size * number of images + - bn: Batch size * number of videos + - f: Number of frames - c: Number of channels (3) - h: Height of each video frame - w: Width of each video frame @@ -137,6 +144,8 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema): type: Literal["pixel_values_videos"] pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] + frames_indices: Annotated[torch.Tensor, TensorShape("bvf")] + frame_duration_ms: Annotated[torch.Tensor, TensorShape("bn")] class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): @@ -148,12 +157,12 @@ class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): """ type: Literal["video_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] + data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")] -NanoNemotronVLVideoInputs = Union[ - NanoNemotronVLVideoPixelInputs, NanoNemotronVLVideoEmbeddingInputs -] +NanoNemotronVLVideoInputs: TypeAlias = ( + NanoNemotronVLVideoPixelInputs | NanoNemotronVLVideoEmbeddingInputs +) def dynamic_preprocess( @@ -248,6 +257,21 @@ def video_to_pixel_values( return torch.stack(frames_tensors) +def input_conditioner(x, norm_mean, norm_std): + return (x - norm_mean) / norm_std + + +def calculate_timestamps( + indices: list[int] | torch.Tensor, + frame_duration_ms: int, +): + if not isinstance(indices, list): + indices = indices.tolist() + + timestamps = [int(i) * frame_duration_ms / 1000.0 for i in indices] + return timestamps + + class BaseNanoNemotronVLProcessor(ABC): """ This model doesn't define its own HF processor, @@ -262,7 +286,7 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *args, - max_num_tiles: Optional[int] = None, + max_num_tiles: int | None = None, **kwargs, ) -> None: super().__init__() @@ -291,7 +315,7 @@ def image_token_id(self) -> int: def get_image_repl( self, feature_size: int, - num_patches: Optional[int], + num_patches: int | None, ) -> PromptUpdateDetails[str]: raise NotImplementedError @@ -341,20 +365,33 @@ def _preprocess_image( else: pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles) image_inputs = { - "pixel_values_flat": torch.cat(pixel_values_lst), + "pixel_values_flat": input_conditioner( + torch.cat(pixel_values_lst), self.norm_mean, self.norm_std + ), "image_num_patches": torch.tensor( [len(item) for item in pixel_values_lst] ), } - for pixel_values in pixel_values_lst: + assert len(text) == 1, ( + "hf_processor is called on the output of get_dummy_text, " + "which should be a single string" + ) + parts = [x for x in re.split(r"(<image>)", text[0]) if x] + assert parts.count("<image>") == len(pixel_values_lst), ( + "the number of <image> tokens in the text should be the " + "same as the number of images" + ) + + for i, pixel_values in 
enumerate(pixel_values_lst): num_patches = pixel_values.shape[0] feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) - text = [t.replace("<image>", image_repl.full, 1) for t in text] + parts[i] = parts[i].replace("<image>", image_repl.full) + text = ["".join(parts)] return text, image_inputs - def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): + def _make_batch_input(self, input_item: Any | list[Any] | None = None): if input_item is None: input_item = [] if not isinstance(input_item, list): @@ -363,10 +400,10 @@ def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): def __call__( self, - text: Optional[Union[str, list[str]]] = None, - images: Optional[Union[Image.Image, list[Image.Image]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - max_num_tiles: Optional[int] = None, + text: str | list[str] | None = None, + images: Image.Image | list[Image.Image] | None = None, + return_tensors: str | TensorType | None = None, + max_num_tiles: int | None = None, ) -> BatchFeature: # Use default if not provided if max_num_tiles is None: @@ -399,12 +436,12 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *, - max_num_tiles: Optional[int] = None, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - video_token: Optional[str] = None, - video_pruning_rate: Optional[float] = None, + max_num_tiles: int | None = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + video_token: str | None = None, + video_pruning_rate: float | None = None, ) -> None: super().__init__( config=config, @@ -418,12 +455,24 @@ def __init__( self.video_token = video_token self.video_pruning_rate = video_pruning_rate + # Pre-tokenize special tokens for video processing + # to avoid repeated tokenization + self._img_start_token_ids = encode_tokens( + tokenizer, IMG_START, add_special_tokens=False + ) + self._img_end_token_ids = encode_tokens( + tokenizer, IMG_END, add_special_tokens=False + ) + self._img_context_token_ids = encode_tokens( + tokenizer, IMG_CONTEXT, add_special_tokens=False + ) + @property def supports_video(self) -> bool: return self.video_token_id is not None @property - def video_token_id(self) -> Optional[int]: + def video_token_id(self) -> int | None: if self.video_token is None: return None return self.tokenizer.get_vocab().get(self.video_token, None) @@ -436,7 +485,7 @@ def _videos_to_pixel_values_lst( self, videos: list[npt.NDArray], max_num_tiles: int, - dynamic_image_size: Optional[bool] = None, + dynamic_image_size: bool | None = None, ) -> list[torch.Tensor]: return [ video_to_pixel_values( @@ -451,24 +500,43 @@ def _videos_to_pixel_values_lst( def _preprocess_video( self, text: list[str], - videos: list[npt.NDArray], + videos: list[tuple[npt.NDArray, dict[str, Any]]], max_num_tiles: int, - dynamic_image_size: Optional[bool] = None, + dynamic_image_size: bool | None = None, ): if len(videos) == 0 or not self.supports_video: video_inputs = {} else: + videos_lst = [v[0] for v in videos] + video_metadata_lst = [v[1] for v in videos] pixel_values_lst_video = self._videos_to_pixel_values_lst( - videos, + videos_lst, max_num_tiles=max_num_tiles, dynamic_image_size=dynamic_image_size, ) + # We use frame duration in milliseconds (as integer) to ensure + # we have consistent timestamps calculation. 
At preprocessing + # fps parameter is given in fp32, while at inference it is bf16 + # which leads to inaccurate timestamp calculation and causes + # timestamp values to differ.In rare cases this causes + # mismatching number of output tokens for tokenized frame prefixes + frame_duration_ms_lst = [ + int(1000.0 / metadata["fps"]) for metadata in video_metadata_lst + ] + frames_indices_lst = [ + metadata["frames_indices"] for metadata in video_metadata_lst + ] + video_inputs = { - "pixel_values_flat_video": torch.cat(pixel_values_lst_video), + "pixel_values_flat_video": input_conditioner( + torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std + ), "video_num_patches": torch.tensor( [len(item) for item in pixel_values_lst_video] ), + "frames_indices": frames_indices_lst, + "frame_duration_ms": torch.tensor(frame_duration_ms_lst), } image_size: int = self.config.force_image_size @@ -478,7 +546,12 @@ def _preprocess_video( (image_size * image_size // patch_size**2) * (downsample_ratio**2) ) - for pixel_values in pixel_values_lst_video: + for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip( + pixel_values_lst_video, + video_metadata_lst, + frames_indices_lst, + frame_duration_ms_lst, + ): num_frames = pixel_values.shape[0] if ( @@ -501,19 +574,32 @@ def _preprocess_video( else: tokens_per_frame = [tokens_in_single_frame] * num_frames - video_repl = self.get_video_repl(tokens_per_frame, self.video_token) + video_repl = self.get_video_repl( + tokens_per_frame=tokens_per_frame, + frames_indices=frames_indices, + frame_duration_ms=frame_duration_ms, + tokenizer=self.tokenizer, + img_start_token_ids=self._img_start_token_ids, + img_end_token_ids=self._img_end_token_ids, + img_context_token_ids=self._img_context_token_ids, + ) - text = [t.replace("<video>", video_repl.full, 1) for t in text] + # video_repl.full is a list of token IDs + # Convert token IDs back to text for the HF processor flow + video_repl_text = self.tokenizer.decode( + video_repl.full, skip_special_tokens=False + ) + text = [t.replace("<video>", video_repl_text, 1) for t in text] return text, video_inputs def __call__( self, - text: Optional[Union[str, list[str]]] = None, - images: Optional[Union[Image.Image, list[Image.Image]]] = None, - videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - max_num_tiles: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + text: str | list[str] | None = None, + images: Image.Image | list[Image.Image] | None = None, + videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None, + return_tensors: str | TensorType | None = None, + max_num_tiles: int | None = None, + dynamic_image_size: bool | None = None, ) -> BatchFeature: # Use default if not provided if max_num_tiles is None: @@ -545,7 +631,7 @@ def __call__( def get_image_repl( self, feature_size: int, - num_patches: Optional[int], + num_patches: int | None, ) -> PromptUpdateDetails[str]: repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END @@ -555,9 +641,15 @@ def get_image_repl( @classmethod def get_video_repl( cls, + *, tokens_per_frame: list[int], - video_context_token: str = IMG_CONTEXT, - ) -> PromptUpdateDetails[str]: + frames_indices: list[int], + frame_duration_ms: int, + tokenizer: AnyTokenizer, + img_start_token_ids: list[int], + img_end_token_ids: list[int], + img_context_token_ids: list[int], + ) -> PromptUpdateDetails[list[int]]: """ Build prompt replacement for a video. 
The replacement returned is not actually used to replace the placeholder @@ -576,16 +668,52 @@ def get_video_repl( - EVS real (called from get_real_video_repl_for_evs) - different value per frame Args: tokens_per_frame (list[int]): number of tokens per frame - video_context_token (str): the token to use for the video context + frames_indices (list[int]): frame indices + frame_duration_ms (int): duration of each frame in milliseconds + tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators + img_start_token_ids (list[int]): pre-tokenized IMG_START tokens + img_end_token_ids (list[int]): pre-tokenized IMG_END tokens + img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens """ - repl_full = "".join( - [ - f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}" - for i, num_tokens in enumerate(tokens_per_frame) + # TODO: Add support of frame_duration_ms to be None + # At preprocessing step we should allow absent / metadata without + # frames_indices field. + timestamps_enabled = frame_duration_ms is not None + + if timestamps_enabled: + timestamps = calculate_timestamps(frames_indices, frame_duration_ms) + + assert len(timestamps) == len(tokens_per_frame), ( + "timestamps and tokens_per_frame must have the same length" + ) + frame_separators = [ + f"Frame {i + 1} sampled at {timestamp:.2f} seconds: " + for i, timestamp in enumerate(timestamps) ] - ) + else: + frame_separators = [ + f"Frame {i + 1}: " for i, _ in enumerate(tokens_per_frame) + ] + + # Tokenize frame separator independently + frame_separators_tokenized = [ + _seq2tokens(tokenizer, sep) for sep in frame_separators + ] - return PromptUpdateDetails.from_seq(repl_full) + # Tokenize each component independently to avoid tokenizer merging tokens + # across boundaries. This ensures consistent tokenization regardless of + # num_tokens_per_frame values. 
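As an illustration of the assembly that follows (a sketch with a toy tokenizer; the real code uses the HF tokenizer via encode_tokens plus the pre-tokenized IMG_START/IMG_CONTEXT/IMG_END IDs), each frame separator is tokenized on its own and the pieces are concatenated as token IDs, so the tokenizer can never merge text across a frame boundary and the token count per frame stays predictable. The timestamp arithmetic mirrors calculate_timestamps() above, using integer milliseconds:

def build_video_repl_ids(tokens_per_frame, frames_indices, frame_duration_ms,
                         tokenize, img_start_ids, img_end_ids, img_context_ids):
    # Timestamps derived from integer frame duration in milliseconds.
    timestamps = [int(i) * frame_duration_ms / 1000.0 for i in frames_indices]
    all_ids: list[int] = []
    for i, (ts, n_ctx) in enumerate(zip(timestamps, tokens_per_frame)):
        all_ids += tokenize(f"Frame {i + 1} sampled at {ts:.2f} seconds: ")
        all_ids += img_start_ids              # <img>
        all_ids += img_context_ids * n_ctx    # one context token per visual token
        all_ids += img_end_ids                # </img>
    return all_ids

# Toy usage: a byte-level "tokenizer" stands in for the real one.
ids = build_video_repl_ids(
    tokens_per_frame=[4, 4],
    frames_indices=[0, 15],
    frame_duration_ms=500,
    tokenize=lambda s: list(s.encode()),
    img_start_ids=[1], img_end_ids=[2], img_context_ids=[3],
)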
+ all_token_ids = [] + for i, num_tokens in enumerate(tokens_per_frame): + frame_sep_token_ids = frame_separators_tokenized[i] + all_token_ids.extend(frame_sep_token_ids) + + # Add pre-tokenized special tokens + all_token_ids.extend(img_start_token_ids) + all_token_ids.extend(img_context_token_ids * num_tokens) + all_token_ids.extend(img_end_token_ids) + + return PromptUpdateDetails.from_seq(all_token_ids) class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): @@ -598,7 +726,7 @@ def get_hf_processor( ) -> BaseNanoNemotronVLProcessor: raise NotImplementedError - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -607,7 +735,7 @@ def get_num_image_tokens( image_width: int, image_height: int, max_num_tiles: int, - processor: Optional[BaseNanoNemotronVLProcessor], + processor: BaseNanoNemotronVLProcessor | None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -673,10 +801,10 @@ def get_supported_mm_limits(self): video_limit = {"video": None} if self.supports_video else {} return {**super().get_supported_mm_limits(), **video_limit} - def get_video_token(self) -> Optional[str]: + def get_video_token(self) -> str | None: return IMG_CONTEXT - def get_video_pruning_rate(self) -> Optional[float]: + def get_video_pruning_rate(self) -> float | None: return self.ctx.get_mm_config().video_pruning_rate def get_num_frames_with_most_features( @@ -692,8 +820,6 @@ def get_num_frames_with_most_features( max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token max_frames_per_video = max_total_frames // max(max_videos, 1) - - max_frames_per_video = min(max_frames_per_video, MAX_FRAMES) return max(max_frames_per_video, 1) def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor: @@ -710,37 +836,12 @@ def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor: class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): """Basic image-only MultiModalProcessor for InternVL-style models.""" - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - processed_outputs = super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - tok_kwargs=tok_kwargs, - ) - - hf_processor = self.info.get_hf_processor(**mm_kwargs) - image_token_id = hf_processor.image_token_id - - # Since there may be extra tokens in the feature placeholders, - # we need to pass the image token ID to the model to select the - # tokens to merge from the vision encoder outputs - processed_outputs["image_token_id"] = torch.tensor(image_token_id) - - return processed_outputs - def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) - num_images = len(image_num_patches) return dict( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( @@ -748,7 +849,6 @@ def _get_mm_fields_config( ), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), - image_token_id=MultiModalFieldConfig.shared("image", num_images), ) def _get_prompt_updates( @@ -814,24 +914,8 @@ class NanoNemotronVLMultiModalProcessor( ): """MultiModalProcessor 
extended for video support""" - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs - ) - - hf_processor = self.info.get_hf_processor(**mm_kwargs) - if ( - self.info.supports_video - and (video_token_id := hf_processor.video_token_id) is not None - ): - processed_outputs["video_token_id"] = torch.tensor(video_token_id) - return processed_outputs + def _get_data_parser(self) -> MultiModalDataParser: + return MultiModalDataParser(video_needs_metadata=True) def _get_mm_fields_config( self, @@ -841,13 +925,14 @@ def _get_mm_fields_config( image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs) if self.info.supports_video: video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) - num_videos = len(video_num_patches) + video_fields = dict( pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( "video", video_num_patches ), video_num_patches=MultiModalFieldConfig.batched("video"), - video_token_id=MultiModalFieldConfig.shared("video", num_videos), + frames_indices=MultiModalFieldConfig.batched("video"), + frame_duration_ms=MultiModalFieldConfig.batched("video"), ) else: video_fields = {} @@ -878,6 +963,7 @@ def _get_prompt_updates( def get_video_replacement_internvl(item_idx: int): feature_size = hf_processor.num_image_token + video, metadata = mm_items["video"][item_idx] num_patches = video_num_patches[item_idx] if num_patches is not None: assert isinstance(num_patches, int) @@ -899,9 +985,15 @@ def get_video_replacement_internvl(item_idx: int): else: tokens_per_frame = [feature_size] * num_patches + frame_duration_ms = int(1000 / metadata["fps"]) return hf_processor.get_video_repl( - tokens_per_frame, - video_context_token=hf_processor.video_token, + tokens_per_frame=tokens_per_frame, + frames_indices=metadata["frames_indices"], + frame_duration_ms=frame_duration_ms, + tokenizer=hf_processor.tokenizer, + img_start_token_ids=hf_processor._img_start_token_ids, + img_end_token_ids=hf_processor._img_end_token_ids, + img_context_token_ids=hf_processor._img_context_token_ids, ) if self.info.supports_video: @@ -929,7 +1021,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: # Use default max_num_tiles for dummy data generation max_num_tiles = 12 @@ -960,11 +1052,42 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return super().get_dummy_text(mm_counts) + "<video>" * num_videos + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + overrides: VideoDummyOptions | None = None, + ) -> list[VideoItem]: + video = super()._get_dummy_videos( + width=width, + height=height, + num_frames=num_frames, + num_videos=1, + overrides=overrides, + )[0] + video_items = [] + for _ in range(num_videos): + video_metadata = { + "total_num_frames": num_frames, + "fps": 2, + "duration": num_frames / 2.0, + "video_backend": "opencv_dynamic", + "frames_indices": [i for i in range(num_frames)], + "do_sample_frames": False, + } + video_item = (video.copy(), video_metadata) + video_items.append(video_item) + + return video_items + def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, 
BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: dummy_image = super().get_dummy_mm_data( seq_len=seq_len, mm_counts=mm_counts, mm_options=mm_options @@ -999,8 +1122,10 @@ def get_dummy_mm_data( class NemotronH_Nano_VL_V2( nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning ): + merge_by_field_config = True + @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" if modality.startswith("video"): @@ -1028,7 +1153,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "language_model"), ) self.vision_model = self.get_vit_model_from_radio_config(config).to( - self.language_model.config.torch_dtype + self.language_model.config.dtype ) # Construct the vision projection. @@ -1049,13 +1174,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ReLUSquaredActivation(), nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False), ) - self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype) + self.mlp1 = self.mlp1.to(self.language_model.config.dtype) - self.img_context_token_id = None - self.video_context_token_id = None self.config = config self.model_config = vllm_config.model_config + # Pre-tokenize special tokens for video processing + # to avoid repeated tokenization + tokenizer = cached_tokenizer_from_config(vllm_config.model_config) + self._img_start_token_ids = encode_tokens( + tokenizer, IMG_START, add_special_tokens=False + ) + self._img_end_token_ids = encode_tokens( + tokenizer, IMG_END, add_special_tokens=False + ) + self._img_context_token_ids = encode_tokens( + tokenizer, IMG_CONTEXT, add_special_tokens=False + ) + def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale @@ -1086,18 +1222,33 @@ def pixel_shuffle(self, x, scale_factor=0.5): return x def extract_feature(self, pixel_values): - vit_embeds = self.vision_model(pixel_values) - vit_embeds = vit_embeds.to(dtype=torch.bfloat16) - h = w = int(vit_embeds.shape[1] ** 0.5) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) - vit_embeds = self.mlp1(vit_embeds) + # Process images in a micro-batch of at most 128 frames per call + # This is done on purpose to ensure peak GPU ram usage of huge batch + # (namely for really long videos with EVS ON) won't cause any problems + # as we don't support chunked prefill for video media + micro_batch_size = 128 + n = pixel_values.shape[0] + vit_embeds_list = [] + for i in range(0, n, micro_batch_size): + vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size]) + vit_embeds = vit_embeds.to(dtype=torch.bfloat16) + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle( + vit_embeds, scale_factor=self.downsample_ratio + ) + vit_embeds = vit_embeds.reshape( + vit_embeds.shape[0], -1, vit_embeds.shape[-1] + ) + vit_embeds = self.mlp1(vit_embeds) + vit_embeds_list.append(vit_embeds) + + vit_embeds = torch.cat(vit_embeds_list, dim=0) return vit_embeds def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[NanoNemotronVLImageInputs]: + ) -> 
NanoNemotronVLImageInputs | None: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -1106,37 +1257,12 @@ def _parse_and_validate_image_input( return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}" - ) - - return NanoNemotronVLImageEmbeddinInputs( + return NanoNemotronVLImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) - image_token_id = kwargs["image_token_id"] - assert isinstance(image_token_id, torch.Tensor) - self.img_context_token_id = image_token_id.flatten().unique().item() - if pixel_values_flat is not None: - if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}" - ) - - if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}" - ) - - pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) - image_num_patches = flatten_bn(image_num_patches, concat=True) - return NanoNemotronVLImagePixelInputs( type="pixel_values", pixel_values_flat=pixel_values_flat, @@ -1185,12 +1311,15 @@ def _process_video_input( rows = int(image_rows * downsample_ratio // patch_size) cols = int(image_cols * downsample_ratio // patch_size) video_pruning_rate = self.video_pruning_rate - + video_num_frames = video_input["num_patches"].tolist() + video_frames_indices = video_input["frames_indices"].split(video_num_frames) # Calculate video feature dimensions (number of frames and # their feature size (AKA tokens per frame)) # TODO: Maybe this can be optimized to avoid the loop? for i, single_video_embeddings in enumerate(video_embeddings): - num_frames = video_input["num_patches"][i].item() + num_frames = video_num_frames[i] + frames_indices = video_frames_indices[i].tolist() + frame_duration_ms = video_input["frame_duration_ms"][i].item() assert single_video_embeddings.shape[0] % num_frames == 0 if video_pruning_rate is not None and video_pruning_rate > 0.0: @@ -1219,6 +1348,8 @@ def _process_video_input( self._create_final_video_embeddings( single_video_embeddings, num_tokens_per_frame, + frames_indices, + frame_duration_ms, ), ) @@ -1228,6 +1359,8 @@ def _create_final_video_embeddings( self, video_embeddings: torch.Tensor, num_tokens_per_frame: list[int], + frames_indices: list[int], + frame_duration_ms: int, ) -> torch.Tensor: """Create final embeddings that combine video embeddings with text embeddings of indicator tokens. @@ -1241,23 +1374,28 @@ def _create_final_video_embeddings( input_embeds for the LLM. 
""" device = video_embeddings.device - - # Generate video replacement text and convert to token IDs - video_repl_text = NanoNemotronVLProcessor.get_video_repl( - num_tokens_per_frame, - IMG_CONTEXT, - ).full - tokenizer = cached_tokenizer_from_config(self.model_config) - repl_token_ids = torch.tensor( - _seq2tokens(tokenizer, video_repl_text), device=device - ) - # Get embedding token IDs for image context - embed_token_ids = torch.tensor( - encode_tokens(tokenizer, IMG_CONTEXT), device=device + # Generate video replacement token IDs using get_video_repl + # This tokenizes each frame separator independently, then uses pre-tokenized + # special tokens to ensure consistent tokenization regardless of + # num_tokens_per_frame values. + video_repl = NanoNemotronVLProcessor.get_video_repl( + tokens_per_frame=num_tokens_per_frame, + frames_indices=frames_indices, + frame_duration_ms=frame_duration_ms, + tokenizer=tokenizer, + img_start_token_ids=self._img_start_token_ids, + img_end_token_ids=self._img_end_token_ids, + img_context_token_ids=self._img_context_token_ids, ) + # video_repl.full is a list of token IDs + repl_token_ids = torch.tensor(video_repl.full, device=device) + + # Get embedding token IDs for image context (use pre-tokenized version) + embed_token_ids = torch.tensor(self._img_context_token_ids, device=device) + # Create mask for video embedding positions is_video_embed = torch.isin(repl_token_ids, embed_token_ids) @@ -1274,10 +1412,12 @@ def _create_final_video_embeddings( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[NanoNemotronVLVideoPixelInputs]: + ) -> NanoNemotronVLVideoPixelInputs | None: pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("video_embeds", None) + frames_indices = kwargs.pop("frames_indices", None) + frame_duration_ms = kwargs.pop("frame_duration_ms", None) if pixel_values_flat_video is None and video_embeds is None: return None @@ -1285,35 +1425,26 @@ def _parse_and_validate_video_input( if video_embeds is not None: return NanoNemotronVLVideoEmbeddingInputs( type="video_embeds", - data=flatten_bn(video_embeds), + data=video_embeds, ) - video_token_id = kwargs["video_token_id"] - assert isinstance(video_token_id, torch.Tensor) - self.video_context_token_id = video_token_id.flatten().unique().item() - if pixel_values_flat_video is not None: - if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat_video)}" - ) - - if not isinstance(video_num_patches, (torch.Tensor, list)): - raise ValueError( - "Incorrect type of image_num_patches. 
" - f"Got type: {type(video_num_patches)}" - ) + if torch.is_tensor(frames_indices): + frames_indices = frames_indices.flatten() + else: + frames_indices = torch.cat([f.flatten() for f in frames_indices], dim=0) - pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True) - video_num_patches = flatten_bn(video_num_patches, concat=True) + frame_duration_ms = frame_duration_ms.flatten() expected_h = expected_w = self.config.force_image_size - resolve_bindings = {"h": expected_h, "w": expected_w} + num_frames = video_num_patches[0].item() + resolve_bindings = {"h": expected_h, "w": expected_w, "f": num_frames} return NanoNemotronVLVideoPixelInputs( type="pixel_values_videos", pixel_values_flat=pixel_values_flat_video, num_patches=video_num_patches, + frames_indices=frames_indices, + frame_duration_ms=frame_duration_ms, resolve_bindings=resolve_bindings, ) @@ -1349,12 +1480,12 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -1365,10 +1496,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: input_ids = None inputs_embeds = None @@ -1396,7 +1527,7 @@ def get_mm_mapping(self) -> MultiModelKeys: def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 8f07a2cf12f7..845798b18d1b 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -26,7 +26,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -85,7 +85,7 @@ def _cast_if_autocast_enabled(*args): class NemotronLayerNorm1P(nn.LayerNorm): def __init__( self, - normalized_shape: Union[int, list[int], torch.Size], + normalized_shape: int | list[int] | torch.Size, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True, @@ -97,7 +97,7 @@ def __init__( def forward( self, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, + residual: torch.Tensor | None = None, ) -> torch.Tensor: if residual is not None: x = x + residual @@ -116,7 +116,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, prefix: str = "", ) -> None: @@ -152,11 +152,11 @@ def __init__( num_heads: int, num_kv_heads: 
int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -238,8 +238,8 @@ class NemotronDecoderLayer(nn.Module): def __init__( self, config: NemotronConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -292,7 +292,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -363,11 +363,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -514,9 +514,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -525,7 +525,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 0a05c63a31ea..f31579e5cfa8 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -18,8 +18,8 @@ # limitations under the License. 
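The NemotronH changes below thread a ParallelConfig into every decoder layer and add an "E" entry (MoE) to the hybrid-override layer map. As a small orientation sketch (toy class names; the real mapping is ALL_DECODER_LAYER_TYPES further down, and the mlp_index arithmetic matches the code moved into NemotronHMLPDecoderLayer), each character of hybrid_override_pattern selects the layer class for that index, and the count of "-" characters up to a layer picks its entry from a per-layer intermediate_size list:

LAYER_TYPES = {"M": "mamba", "-": "mlp", "*": "attention", "E": "moe"}

def layer_plan(pattern: str) -> list[tuple[int, str]]:
    return [(idx, LAYER_TYPES[ch]) for idx, ch in enumerate(pattern)]

def mlp_index(pattern: str, layer_idx: int) -> int:
    # Index into config.intermediate_size when it is given per MLP layer.
    return pattern[: layer_idx + 1].count("-") - 1

pattern = "M-M*M-ME"
print(layer_plan(pattern))    # [(0, 'mamba'), (1, 'mlp'), (2, 'mamba'), (3, 'attention'), ...]
print(mlp_index(pattern, 5))  # 1 -> second MLP layer in the pattern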
"""Inference-only NemotronH model.""" -from collections.abc import Iterable -from typing import Optional +import typing +from collections.abc import Callable, Iterable import torch from torch import nn @@ -27,13 +27,18 @@ from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.config.parallel import ParallelConfig +from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size +from vllm.distributed.communication_op import tensor_model_parallel_all_gather from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.activation import ReLUSquaredActivation +from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE +from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -55,6 +60,7 @@ from vllm.model_executor.models.interfaces import ( HasInnerState, IsHybrid, + MixtureOfExperts, SupportsLoRA, SupportsPP, SupportsQuant, @@ -62,9 +68,11 @@ from vllm.model_executor.models.utils import ( AutoWeightsLoader, WeightsMapper, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, + sequence_parallel_chunk, ) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig @@ -74,28 +82,21 @@ class NemotronHMLP(nn.Module): def __init__( self, config: NemotronHConfig, - layer_idx: int, - quant_config: Optional[QuantizationConfig] = None, + intermediate_size: int, + quant_config: QuantizationConfig | None = None, bias: bool = False, + reduce_results: bool = True, + is_sequence_parallel: bool = False, prefix: str = "", ) -> None: super().__init__() - hybrid_override_pattern = config.hybrid_override_pattern - mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1 - if isinstance(config.intermediate_size, list): - if len(config.intermediate_size) == 1: - intermediate_size = config.intermediate_size[0] - else: - intermediate_size = config.intermediate_size[mlp_index] - else: - intermediate_size = config.intermediate_size - self.up_proj = ColumnParallelLinear( input_size=config.hidden_size, output_size=intermediate_size, bias=bias, quant_config=quant_config, + disable_tp=is_sequence_parallel, prefix=f"{prefix}.up_proj", ) self.down_proj = RowParallelLinear( @@ -103,6 +104,8 @@ def __init__( output_size=config.hidden_size, bias=bias, quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, prefix=f"{prefix}.down_proj", ) self.act_fn = ReLUSquaredActivation() @@ -114,33 +117,207 @@ def forward(self, x: torch.Tensor): return x +class NemotronHMoE(nn.Module): + def __init__( + self, + config: NemotronHConfig, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + 
self.n_shared_experts: int = config.n_shared_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + params_dtype=torch.float32, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) + # Load balancing settings. + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts # noqa: E501 + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + if config.n_shared_experts is None or config.n_shared_experts == 0: + self.shared_experts = None + else: + intermediate_size = ( + config.moe_shared_expert_intermediate_size * config.n_shared_experts + ) + + self.shared_experts = NemotronHMLP( + config=config, + intermediate_size=intermediate_size, + quant_config=quant_config, + reduce_results=False, + is_sequence_parallel=self.is_sequence_parallel, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + activation=activation_without_mul(config.mlp_hidden_act), + is_act_and_mul=False, # non-gated MoE + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + shared_output, final_hidden_states = fused_moe_out + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. 
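A quick numeric illustration of the float16 special-casing that follows this comment (borrowed, as the comment notes, from the DeepseekV2 layers): float16 saturates near 65504, so up-scaling the routed output by routed_scaling_factor can overflow, whereas down-scaling the shared-expert output by the same factor keeps the relative weighting of the two branches while staying in range.

import torch

x = torch.tensor([40000.0], dtype=torch.float16)
scale = 2.5
print(x * scale)   # tensor([inf], dtype=torch.float16)   -- overflows fp16
print(x / scale)   # tensor([16000.], dtype=torch.float16) -- representable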
+ if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + + return final_hidden_states.view(num_tokens, hidden_dim) + + class NemotronHMLPDecoderLayer(nn.Module): def __init__( self, config: NemotronHConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.config = config + hybrid_override_pattern = config.hybrid_override_pattern + mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1 + if isinstance(config.intermediate_size, list): + if len(config.intermediate_size) == 1: + intermediate_size = config.intermediate_size[0] + else: + intermediate_size = config.intermediate_size[mlp_index] + else: + intermediate_size = config.intermediate_size + self.mixer = NemotronHMLP( config, + intermediate_size=intermediate_size, quant_config=quant_config, bias=config.mlp_bias, prefix=f"{prefix}.mixer", - layer_idx=layer_idx, ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states) + return hidden_states, residual + + +class NemotronHMoEDecoderLayer(nn.Module): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.mixer = NemotronHMoE( + config, + quant_config=quant_config, + parallel_config=parallel_config, + prefix=f"{prefix}.mixer", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -158,9 +335,10 @@ def __init__( self, config: NemotronHConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -175,7 +353,7 @@ def __init__( n_groups=config.n_groups, 
num_heads=config.mamba_num_heads, head_dim=config.mamba_head_dim, - rms_norm_eps=config.rms_norm_eps, + rms_norm_eps=config.layer_norm_epsilon, activation=config.mamba_hidden_act, model_config=model_config, cache_config=cache_config, @@ -183,12 +361,12 @@ def __init__( prefix=f"{prefix}.mixer", ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -207,9 +385,9 @@ def __init__( self, config: NemotronHConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -279,9 +457,10 @@ def __init__( self, config: NemotronHConfig, layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -295,13 +474,13 @@ def __init__( prefix=f"{prefix}.mixer", ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -318,6 +497,7 @@ def forward( "M": NemotronHMambaDecoderLayer, "-": NemotronHMLPDecoderLayer, "*": NemotronHAttentionDecoderLayer, + "E": NemotronHMoEDecoderLayer, } @@ -330,6 +510,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config lora_config = vllm_config.lora_config self.config = config @@ -347,17 +528,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) + self.has_moe = "E" in config.hybrid_override_pattern + def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ config.hybrid_override_pattern[layer_idx] ] return layer_class( - config, - layer_idx, - model_config, - cache_config, + config=config, + layer_idx=layer_idx, + model_config=model_config, + cache_config=cache_config, quant_config=quant_config, + parallel_config=parallel_config, prefix=prefix, ) @@ -368,7 +552,7 @@ def get_layer(prefix: str): ["hidden_states", "residual"], config.hidden_size ) - self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -377,8 +561,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | 
None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -414,6 +598,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("qkv_proj", "v_proj", "v"), ] + if self.has_moe: + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's + # what the activation is applied to + # - FusedMoe.w3 (aka up_proj) should be ignored since we're + # using non-gated MoE + ckpt_gate_proj_name="up_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="", + num_experts=self.config.n_routed_experts, + num_redundant_experts=getattr(self, "num_redundant_experts", 0), + ) + else: + expert_params_mapping = [] + params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -439,16 +639,62 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # load other params else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + name = name_mapped + break + else: + if is_expert_weight: + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class NemotronHForCausalLM( - nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + IsHybrid, + SupportsQuant, + MixtureOfExperts, ): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"backbone": "model"}, @@ -546,6 +792,61 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intmd_tensors = self.model.make_empty_intmd_tensors + # Set MoE hyperparameters + if self.model.has_moe: + self.expert_weights = [] + self.num_expert_groups = config.n_group + + self.moe_layers: list[SharedFusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, NemotronHMoEDecoderLayer): + # Pick last one layer since the first ones + # may be dense layers. 
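For orientation on the expert bookkeeping collected in the MoE layers and refreshed in update_physical_experts_metadata below, here is a small sketch (illustrative names only) of how logical, redundant, physical, and local expert counts relate across expert-parallel ranks, following the same arithmetic as NemotronHMoE.__init__:

def expert_layout(n_logical: int, n_redundant: int, ep_size: int, ep_rank: int) -> dict:
    # Physical experts = logical experts + redundant replicas, split evenly per EP rank.
    n_physical = n_logical + n_redundant
    assert n_physical % ep_size == 0, "physical experts must divide evenly across ranks"
    n_local = n_physical // ep_size
    start = ep_rank * n_local
    return {
        "n_physical": n_physical,
        "n_local": n_local,
        "local_expert_range": (start, start + n_local),
    }

print(expert_layout(n_logical=64, n_redundant=8, ep_size=8, ep_rank=3))
# {'n_physical': 72, 'n_local': 9, 'local_expert_range': (27, 36)}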
+ example_moe = layer.mixer + self.moe_layers.append(layer.mixer.experts) + + self.num_moe_layers = len(self.moe_layers) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts # noqa: E501 + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer, NemotronHMoEDecoderLayer): + moe = layer.mixer + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -553,8 +854,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): hidden_states = self.model( @@ -566,7 +867,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index ddd623b5de23..17e009612df4 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -26,7 +26,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -84,12 +84,12 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, bias_o_proj: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -112,8 +112,8 @@ def __init__( def _init_rotary_emb( self, config, - rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig], + rope_scaling: dict[str, Any] | None, + quant_config: 
QuantizationConfig | None, ) -> None: # Enables YARN for Mistral and LLaMA4 derivatives. is_neox_style = True @@ -139,8 +139,8 @@ def __init__( self, config: LlamaConfig, layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -210,7 +210,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -303,11 +303,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -487,9 +487,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -498,7 +498,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index 268644bc9249..2f78e2f60c93 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -9,7 +9,6 @@ # -------------------------------------------------------- from abc import ABC from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -207,9 +206,9 @@ def __init__( tokenizer: AnyTokenizer, image_processor: BaseImageProcessorFast, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, ) -> None: ABC.__init__(self) self.config = config @@ -266,9 +265,9 @@ def get_num_image_tokens( def _images_to_pixel_values_lst( self, images: list[Image.Image], - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, ) -> list[torch.Tensor]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -292,9 +291,9 @@ def _preprocess_image( self, text: list[str], images: list[Image.Image], - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = 
None, + dynamic_image_size: bool | None = None, ) -> tuple[list[str], dict[str, torch.Tensor]]: if len(images) == 0: image_inputs = {} @@ -326,7 +325,7 @@ def _preprocess_image( def get_image_repl( self, feature_size: int, - num_patches: Optional[int], + num_patches: int | None, ) -> PromptUpdateDetails[str]: repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END @@ -362,7 +361,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor merge_by_field_config = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -426,7 +425,7 @@ def _patch_quant_config( def _init_vision_model( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, prefix: str, ): @@ -482,7 +481,7 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[InternVLImageInputs]: + ) -> InternVLImageInputs | None: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -497,8 +496,11 @@ def _parse_and_validate_image_input( ) image_token_id = kwargs["image_token_id"] - assert isinstance(image_token_id, torch.Tensor) - self.img_context_token_id = image_token_id.flatten().unique().item() + if isinstance(image_token_id, torch.Tensor): + image_token_id = image_token_id.flatten().unique().item() + + assert isinstance(image_token_id, int) + self.img_context_token_id = image_token_id if pixel_values_flat is not None: return InternVLImagePixelInputs( @@ -573,17 +575,17 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) return multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: @@ -604,8 +606,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> IntermediateTensors: if intermediate_tensors is not None: @@ -630,7 +632,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index f17bf3b09d5b..73dd8dfd0f85 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -8,7 +8,6 
@@ # Licensed under Apache 2.0 License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Mapping, Sequence -from typing import Optional import torch import torch.nn as nn @@ -49,7 +48,7 @@ def image_token_id(self) -> int: def get_image_repl( self, feature_size: int, - num_patches: Optional[int], + num_patches: int | None, ) -> PromptUpdateDetails[str]: if num_patches is None: raise NotImplementedError("Embedding inputs are not supported") @@ -93,7 +92,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) @@ -189,7 +188,7 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: def _init_vision_model( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, is_mono: bool, prefix: str, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index f334bbf9feeb..390a91d3425c 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -26,7 +26,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -65,15 +64,15 @@ class OlmoAttention(nn.Module): """ This is the attention block where the output is computed as - ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + `Attention(LN(x))` in `MLP(LN(x + Attention(LN(x))))` (plus another skip connection). """ def __init__( self, config: OlmoConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -145,14 +144,14 @@ def forward( class OlmoMLP(nn.Module): """ This is the MLP block where the output is computed as - ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + `MLP(LN(x))` in `MLP(LN(x + Attention(LN(x))))` (plus another skip connection). """ def __init__( self, config: OlmoConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -194,15 +193,15 @@ def forward( class OlmoDecoderLayer(nn.Module): """ This is a typical transformer block where the output is - computed as ``MLP(LN(x + Attention(LN(x))))`` + computed as `MLP(LN(x + Attention(LN(x))))` (plus another skip connection). """ def __init__( self, config: OlmoConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -226,7 +225,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]: + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: # Attention block. 
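# --- Editor's note (not part of the patch) ----------------------------------
# Most hunks in this file, as in the rest of this diff, replace typing.Optional
# and typing.Union with PEP 604 unions. A minimal, standard-library-only sketch
# of the equivalence on Python 3.10+; the function names here are illustrative:
from typing import Optional, Union

def old_style(x: Optional[int], y: Union[int, str]) -> Optional[str]:
    return None if x is None else f"{x}:{y}"

def new_style(x: int | None, y: int | str) -> str | None:
    return None if x is None else f"{x}:{y}"

# The two spellings compare equal at runtime, so annotations stay compatible.
assert Optional[int] == (int | None)
assert Union[int, str] == (int | str)
assert old_style(1, "a") == new_style(1, "a") == "1:a"
# -----------------------------------------------------------------------------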
residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -276,9 +275,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. """ @@ -389,9 +388,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -403,7 +402,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 79234cc4dd8d..7e39f6dff25e 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -27,7 +27,6 @@ from collections.abc import Iterable from functools import partial from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -70,7 +69,7 @@ class Olmo2Attention(nn.Module): """ This is the attention block where the output is computed as - ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + `Attention(LN(x))` in `MLP(LN(x + Attention(LN(x))))` (plus another skip connection). """ @@ -191,7 +190,7 @@ def forward( class Olmo2MLP(nn.Module): """ This is the MLP block where the output is computed as - ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))`` + `MLP(x)` in `LN(MLP(x + LN(Attention(x))))` (plus another skip connection). """ @@ -236,7 +235,7 @@ def forward( class Olmo2DecoderLayer(nn.Module): """ This is a typical transformer block where the output is - computed as ``MLP(LN(x + Attention(LN(x))))`` + computed as `MLP(LN(x + Attention(LN(x))))` (plus another skip connection). """ @@ -312,9 +311,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. 
""" @@ -429,9 +428,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -443,7 +442,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 90ec1a890417..7f867244330f 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -17,15 +17,13 @@ from collections.abc import Iterable from functools import partial from itertools import islice -from typing import Any, Optional, Union import torch from torch import nn -from transformers import OlmoeConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import ( get_pp_group, get_tensor_model_parallel_rank, @@ -51,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, is_pp_missing_parameter, @@ -78,9 +76,9 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + tp_size: int | None = None, prefix: str = "", ): super().__init__() @@ -117,20 +115,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class OlmoeAttention(nn.Module): - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, - max_position_embeddings: int = 4096, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() - self.hidden_size = hidden_size + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 4096) + + num_heads = config.num_attention_heads + num_kv_heads = config.num_key_value_heads + tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 @@ -145,7 +144,7 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads + self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -153,7 +152,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size, + self.hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, @@ -166,7 +165,7 @@ def __init__( self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, - hidden_size, + self.hidden_size, bias=False, quant_config=quant_config, ) @@ -218,28 +217,15 @@ def forward( class OlmoeDecoderLayer(nn.Module): - def __init__( - self, - config: OlmoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", 4096) self.self_attn = OlmoeAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, + vllm_config=vllm_config, prefix=f"{prefix}.self_attn", ) @@ -258,7 +244,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual is None: @@ -280,12 +266,16 @@ def forward( @support_torch_compile class OlmoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = OlmoeDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.vocab_size = config.vocab_size self.config = config @@ -295,9 +285,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: OlmoeDecoderLayer( - config, cache_config, quant_config, prefix=prefix - ), + lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=1e-5) @@ -313,9 +301,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -339,7 +327,10 @@ def forward( {"hidden_states": hidden_states, "residual": residual} ) - 
hidden_states, _ = self.norm(hidden_states, residual) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -358,8 +349,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) @@ -442,27 +431,31 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params -class OlmoeForCausalLM(nn.Module, SupportsPP): +class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + ] } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = OlmoeDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config self.model = OlmoeModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type, ) self.lm_head = ParallelLMHead( config.vocab_size, @@ -483,9 +476,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index eadfea6084e5..d124b7671b9c 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -22,7 +22,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -77,8 +76,8 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -129,8 +128,8 @@ class OPTDecoderLayer(nn.Module): def __init__( self, config: OPTConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -202,8 +201,8 @@ class OPTDecoder(nn.Module): def __init__( self, config: OPTConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -270,9 +269,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> 
Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings(input_ids) @@ -319,9 +318,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: return self.decoder( input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds ) @@ -402,9 +401,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -413,7 +412,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 0ce172938955..cfe4d0333418 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -9,7 +9,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -51,7 +51,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -80,10 +80,10 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -158,8 +158,8 @@ class OrionDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -250,9 +250,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -340,9 +340,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: 
Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -351,7 +351,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 12ed7b4c2ed0..cc6c9b4e72d7 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -20,7 +20,7 @@ import math from collections.abc import Iterable, Mapping -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch import torch.nn as nn @@ -87,7 +87,7 @@ class VisualTokenizer(torch.nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -114,7 +114,7 @@ def __init__( def _init_backbone( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> nn.Module: model_type = config.backbone_config.model_type @@ -166,7 +166,7 @@ def encode(self, pixel_values: torch.Tensor) -> torch.Tensor: # e.g., for hidden_stride=2, this leads to a token length reduction: # 1024 -> 256 for aimv2 if self.config.hidden_stride > 1: - # this `d` maybe different from the above `d`` + # this `d` maybe different from the above `d` n, L, d = features.shape sqrt_l = int(L**0.5) assert sqrt_l**2 == L, ( @@ -282,7 +282,7 @@ def get_image_pad_token(self) -> str: text_model_type = hf_text_config.model_type return IMAGE_PAD_TOKEN_MAP.get(text_model_type) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_image_size_with_most_features(self) -> ImageSize: @@ -302,7 +302,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -417,7 +417,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): merge_by_field_config = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -453,7 +453,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[OvisImagePatchInputs]: + ) -> OvisImagePatchInputs | None: pixel_values = kwargs.pop("pixel_values", None) indicator_tokens = kwargs.pop("indicator_tokens", None) @@ -527,10 +527,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> 
torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -547,7 +547,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.llm.compute_logits(hidden_states) return logits diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index bb4fb1d17c15..f6461ae9a412 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -4,12 +4,13 @@ from collections.abc import Iterable, Mapping from functools import partial -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig +from vllm.attention.backends.registry import _Backend from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear @@ -102,9 +103,10 @@ def __init__( self, config: PretrainedConfig, visual_vocab_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ): super().__init__() self.config = config @@ -113,6 +115,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.vit", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) # reserved tokens for INDICATOR_IDS head_dim = visual_vocab_size - len(INDICATOR_IDS) @@ -129,9 +132,10 @@ def __init__( def _init_backbone( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ): model_type = config.model_type if model_type == "siglip2_navit": @@ -140,6 +144,7 @@ def _init_backbone( quant_config=quant_config, prefix=prefix, use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") @@ -205,7 +210,7 @@ def get_image_pad_token(self) -> str: def get_image_processor(self) -> BaseImageProcessor: return self.get_hf_processor().image_processor # type: ignore - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": 1} def get_image_size_with_most_features(self) -> ImageSize: @@ -274,7 +279,7 @@ def get_num_video_tokens( image_width: int, image_height: int, num_frames: int, - image_processor: Optional[BaseImageProcessor], + image_processor: BaseImageProcessor | None, ) -> int: num_video_tokens = self.get_num_image_tokens( image_width=image_width, image_height=image_height, num_frames=num_frames @@ -305,7 +310,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -457,6 +462,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config self.config: PretrainedConfig = 
config self.llm = init_vllm_registered_model( @@ -464,11 +470,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "llm"), ) + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.visual_tokenizer = VisualTokenizer( config=config.vit_config, visual_vocab_size=config.visual_vocab_size, quant_config=quant_config, prefix=f"{prefix}.visual_tokenizer", + attn_backend_override=attn_backend_override, ) self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size) @@ -482,7 +494,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Ovis2_5ImagePatchInputs]: + ) -> Ovis2_5ImagePatchInputs | None: pixel_values = kwargs.pop("pixel_values", None) indicator_tokens = kwargs.pop("indicator_tokens", None) grids = kwargs.pop("grids", None) @@ -516,7 +528,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[Ovis2_5VideoPatchInputs]: + ) -> Ovis2_5VideoPatchInputs | None: pixel_values = kwargs.pop("video_pixel_values", None) indicator_tokens = kwargs.pop("video_indicator_tokens", None) grids = kwargs.pop("video_grids", None) @@ -549,7 +561,7 @@ def _parse_and_validate_video_input( raise AssertionError("This line should be unreachable.") def _process_visual_input( - self, visual_input: Union[Ovis2_5ImagePatchInputs, Ovis2_5VideoPatchInputs] + self, visual_input: Ovis2_5ImagePatchInputs | Ovis2_5VideoPatchInputs ) -> MultiModalEmbeddings: image_patches_flat = visual_input["flat_data"] patches_per_image = visual_input["patches_per_item"] @@ -616,12 +628,12 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_visual_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_visual_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_visual_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -629,10 +641,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -649,7 +661,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.llm.compute_logits(hidden_states) return logits diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 7bddfc5ee855..fb0b4b290467 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, 
Literal, TypeAlias import torch from torch import nn @@ -74,7 +74,9 @@ class PaliGemmaImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, PaliGemmaImageEmbeddingInputs] +PaliGemmaImageInputs: TypeAlias = ( + PaliGemmaImagePixelInputs | PaliGemmaImageEmbeddingInputs +) class PaliGemmaMultiModalProjector(nn.Module): @@ -95,7 +97,7 @@ def get_hf_config(self): def get_vision_encoder_info(self): return get_vision_encoder_info(self.get_hf_config()) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": 1} def get_num_image_tokens( @@ -120,7 +122,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config @@ -217,11 +219,11 @@ def get_insertion(item_idx: int): def apply( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_uuids: Optional[MultiModalUUIDDict] = None, + tokenization_kwargs: Mapping[str, object] | None = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: mm_inputs = super().apply( prompt, @@ -273,7 +275,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None @@ -317,7 +319,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[PaliGemmaImageInputs]: + ) -> PaliGemmaImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -386,8 +388,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> IntermediateTensors: if intermediate_tensors is not None: @@ -402,7 +404,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index d3df5f9a59b5..2c62f6862cf2 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -25,7 +25,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -63,7 +62,7 @@ class PersimmonMLP(nn.Module): def __init__( - self, config: PersimmonConfig, quant_config: Optional[QuantizationConfig] = None + self, config: PersimmonConfig, quant_config: QuantizationConfig | None = None ): super().__init__() self.dense_h_to_4h = ColumnParallelLinear( @@ -85,8 +84,8 @@ class PersimmonAttention(nn.Module): def 
__init__( self, config: PersimmonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -181,8 +180,8 @@ class PersimmonDecoderLayer(nn.Module): def __init__( self, config: PersimmonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -263,9 +262,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -340,8 +339,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ): hidden_states = self.model( input_ids=input_ids, @@ -354,7 +353,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 779b391008bb..6adcaf5084cb 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -40,7 +40,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -80,8 +79,8 @@ class PhiAttention(nn.Module): def __init__( self, config: PhiConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -149,7 +148,7 @@ def forward( class PhiMLP(nn.Module): def __init__( - self, config: PhiConfig, quant_config: Optional[QuantizationConfig] = None + self, config: PhiConfig, quant_config: QuantizationConfig | None = None ): super().__init__() @@ -179,8 +178,8 @@ class PhiLayer(nn.Module): def __init__( self, config: PhiConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -241,9 +240,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -348,9 +347,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: 
Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -360,7 +359,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias) return logits diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index d972604db9cd..b86fe67fb476 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import regex as re import torch @@ -56,7 +56,6 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel @@ -70,7 +69,6 @@ AutoWeightsLoader, WeightsMapper, _merge_multimodal_embeddings, - flatten_bn, init_vllm_registered_model, maybe_prefix, ) @@ -96,7 +94,7 @@ def _init_img_processor( hf_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", ) -> CLIPVisionModel: clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG @@ -132,14 +130,14 @@ class Phi3VImagePixelInputs(TensorSchema): # Supports either a stacked tensor or a list of (p, 3, h, w) tensors pixel_values: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape( "bn", "p", 3, "h", "w", dynamic_dims={"p"} ), # 'p' may vary across items ] # Stacked tensor with height and width for each image - image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] + image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)] class Phi3VImageEmbeddingInputs(TensorSchema): @@ -153,12 +151,12 @@ class Phi3VImageEmbeddingInputs(TensorSchema): type: Literal["image_embeds"] = "image_embeds" data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", "h"), ] -Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs] +Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs class Phi3ImageEmbeddingBase(nn.Module): @@ -192,7 +190,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", ) -> None: super().__init__() @@ -350,7 +348,7 @@ def add_image_newline(self, image_features_hd): class Phi3VProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -358,7 +356,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[ProcessorMixin] = None, + processor: ProcessorMixin | None = None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -386,7 
+384,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -564,6 +562,8 @@ def _apply_prompt_updates( dummy_inputs=Phi3VDummyInputsBuilder, ) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant): + merge_by_field_config = True + hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", @@ -574,7 +574,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return f"<|image_{i}|>" @@ -620,7 +620,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Phi3VImageInputs]: + ) -> Phi3VImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -631,8 +631,8 @@ def _parse_and_validate_image_input( if pixel_values is not None: return Phi3VImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values), - image_sizes=flatten_bn(image_sizes, concat=True), + pixel_values=pixel_values, + image_sizes=image_sizes, resolve_bindings={ "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, @@ -642,7 +642,7 @@ def _parse_and_validate_image_input( if image_embeds is not None: return Phi3VImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") @@ -652,19 +652,10 @@ def _process_image_input( image_input: Phi3VImageInputs, ) -> torch.Tensor: if image_input["type"] == "image_embeds": - image_data = image_input["data"] - if is_list_of(image_data, torch.Tensor): - # it's already a list of tensors - return image_data - if len(image_data.shape) == 3: - # 3D tensor - return list(torch.unbind(image_data, dim=0)) - raise ValueError( - "We expect batched 2D tensors; " - "this can be either a list of 2D tensors or a single 3D tensor." 
- ) + return image_input["data"] assert self.vision_embed_tokens is not None + image_embeds = self.vision_embed_tokens( image_input["pixel_values"], image_input["image_sizes"] ) @@ -684,9 +675,9 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: inputs_embeds = self._get_text_embeddings( @@ -716,8 +707,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): if intermediate_tensors is not None: @@ -732,7 +723,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index 002233d0677b..4799b7aba7f7 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import numpy as np import torch @@ -64,7 +64,6 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer @@ -72,7 +71,6 @@ from .utils import ( AutoWeightsLoader, WeightsMapper, - flatten_bn, init_vllm_registered_model, maybe_prefix, ) @@ -147,7 +145,7 @@ def __init__(self, config: Phi4MultimodalConfig): def get_img_features( self, img_embeds: torch.FloatTensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor | None = None, ) -> torch.FloatTensor: img_feature = self.img_processor( img_embeds, patch_attention_mask=attention_mask @@ -172,8 +170,8 @@ def get_img_features( def forward( self, image_pixel_values: torch.FloatTensor, - image_sizes: Optional[torch.Tensor] = None, - image_attention_mask: Optional[torch.Tensor] = None, + image_sizes: torch.Tensor | None = None, + image_attention_mask: torch.Tensor | None = None, ) -> torch.FloatTensor: image_pixel_values = image_pixel_values.to( self.img_processor.embeddings.patch_embedding.weight.dtype @@ -278,7 +276,7 @@ class Phi4MultimodalAudioMLP(nn.Module): def __init__( self, config: Phi4MultimodalAudioConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -311,7 +309,7 @@ class Phi4MultimodalAudioAttention(nn.Module): def __init__( self, config: Phi4MultimodalAudioConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -522,7 
+520,7 @@ def calculate_hs_mask( pad_mask = pad_mask & enc_streaming_mask return pad_mask - def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None): + def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None): hidden_states = self.encoder_embedding(hidden_states) hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) @@ -672,8 +670,8 @@ class Phi4MMImagePixelInputs(TensorSchema): type: Literal["pixel_values"] - data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + pixel_values: Annotated[ + torch.Tensor | list[torch.Tensor], TensorShape( "bn", "p", 3, "h", "w", dynamic_dims={"p"} ), # may be different per batch and image @@ -706,7 +704,7 @@ class Phi4MMImageEmbeddingInputs(TensorSchema): type: Literal["image_embeds"] data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", "h"), ] @@ -721,8 +719,8 @@ class Phi4MMAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] - data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + audio_features: Annotated[ + torch.Tensor | list[torch.Tensor], TensorShape("bn", "t", 80, dynamic_dims={"t"}), ] @@ -744,8 +742,8 @@ class Phi4MMAudioEmbeddingInputs(TensorSchema): ] -Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs] -Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] +Phi4MMImageInput: TypeAlias = Phi4MMImagePixelInputs | Phi4MMImageEmbeddingInputs +Phi4MMAudioInputs: TypeAlias = Phi4MMAudioFeatureInputs | Phi4MMAudioEmbeddingInputs def cat_with_pad(tensors, dim, padding_value=0): @@ -786,7 +784,7 @@ def get_feature_extractor(self, **kwargs: object) -> Phi4MultimodalFeatureExtrac def get_image_processor( self, - processor: Optional[Phi4MMProcessor] = None, + processor: Phi4MMProcessor | None = None, ) -> Phi4MultimodalImageProcessorFast: if processor is None: processor = self.get_hf_processor() @@ -794,11 +792,11 @@ def get_image_processor( def get_dynamic_hd( self, - processor: Optional[Phi4MMProcessor] = None, + processor: Phi4MMProcessor | None = None, ) -> int: return self.get_image_processor(processor).dynamic_hd - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": None, "image": None} def _find_target_aspect_ratio( @@ -936,7 +934,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[Phi4MMProcessor] = None, + processor: Phi4MMProcessor | None = None, ) -> int: hf_config = self.get_hf_config() vision_config = hf_config.vision_config @@ -959,7 +957,7 @@ def get_num_image_tokens( def get_image_size_with_most_features( self, - processor: Optional[Phi4MMProcessor] = None, + processor: Phi4MMProcessor | None = None, ) -> ImageSize: vit_image_size = self.get_hf_config().vision_config.image_size @@ -1038,7 +1036,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -1189,6 +1187,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): Implements the Phi-4-multimodal-instruct model in vLLM. 
""" + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -1216,7 +1216,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|image|>" if modality.startswith("audio"): @@ -1253,7 +1253,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[Phi4MMAudioInputs]: + ) -> Phi4MMAudioInputs | None: """ Parse and validate the audio input to the model. This handles both audio features and audio embeddings, but only the former is used for @@ -1273,7 +1273,8 @@ def _parse_and_validate_audio_input( if audio_features is not None: return Phi4MMAudioFeatureInputs( - type="audio_features", data=flatten_bn(audio_features) + type="audio_features", + audio_features=audio_features, ) if audio_embeds is not None: @@ -1298,7 +1299,7 @@ def _process_audio_input( if audio_input["type"] == "audio_embeds": return audio_input["data"] - audio_features = audio_input["data"] + audio_features = audio_input["audio_features"] # (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple audios in the same example) @@ -1314,9 +1315,9 @@ def _process_audio_input( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Phi4MMImagePixelInputs]: - image_pixel_values: NestedTensors = kwargs.get("image_pixel_values") - if image_pixel_values is None: + ) -> Phi4MMImagePixelInputs | None: + pixel_values = kwargs.get("image_pixel_values") + if pixel_values is None: return None image_sizes = kwargs.get("image_sizes") @@ -1328,52 +1329,9 @@ def _parse_and_validate_image_input( and num_img_tokens is not None ), "Missing image inputs" - if is_list_of(image_pixel_values, torch.Tensor): - assert all(p.dim() == 5 for p in image_pixel_values), ( - "Incorrect image inputs" - ) - # list len is batch_size. - # each tensor has dimension: num_img_per_example, num_hd_patches, - # channels, height, width. - # need to pad along num_hd_patches. - # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. - image_pixel_values = cat_with_pad(image_pixel_values, dim=0) - elif isinstance(image_pixel_values, torch.Tensor): - # dimension: batch_size, num_img_per_example, num_hd_patches, - # channels, height, width. - # we flatten first 2 dims to make it a single large batch for - # SigLIP Encoder. 
- assert image_pixel_values.dim() == 6, "Incorrect image inputs" - image_pixel_values = image_pixel_values.flatten(0, 1) - else: - raise ValueError("Incorrect image_pixel_values inputs") - - if isinstance(image_attention_mask, list): - image_attention_mask = cat_with_pad(image_attention_mask, dim=0) - elif isinstance(image_attention_mask, torch.Tensor): - image_attention_mask = image_attention_mask.flatten(0, 1) - else: - raise ValueError("Incorrect image_attention_mask inputs") - - if isinstance(image_sizes, list): - image_sizes = torch.cat(image_sizes, dim=0) - elif isinstance(image_sizes, torch.Tensor): - image_sizes = image_sizes.flatten(0, 1) - else: - raise ValueError("Incorrect image_sizes inputs") - - if isinstance(num_img_tokens, list): - num_img_tokens = [ - n for num_tensor in num_img_tokens for n in num_tensor.tolist() - ] - elif isinstance(num_img_tokens, torch.Tensor): - num_img_tokens = num_img_tokens.flatten(0, 1).tolist() - else: - raise ValueError("Incorrect num_img_tokens inputs") - return Phi4MMImagePixelInputs( type="pixel_values", - data=image_pixel_values, + pixel_values=pixel_values, image_sizes=image_sizes, image_attention_mask=image_attention_mask, num_img_tokens=num_img_tokens, @@ -1405,7 +1363,7 @@ def _process_image_input( image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: dtype = next(self.image_embed.parameters()).dtype - pixel_values = image_input["data"].to(dtype) + pixel_values = image_input["pixel_values"].to(dtype) image_sizes = image_input["image_sizes"] image_attention_mask = image_input["image_attention_mask"] image_embeds = self.image_embed( @@ -1430,8 +1388,8 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: if modality == "images": audio_projection_mode = "vision" image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += tuple(vision_embeddings) + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "audios": audio_input = modalities["audios"] audio_embeddings = self._process_audio_input( @@ -1445,8 +1403,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: @@ -1464,7 +1422,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 981f9b37846f..acad72b058fc 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import numpy as np import torch @@ -50,13 +50,12 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, 
TensorShape from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .phi4mm_audio import AudioEmbedding -from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix # <|endoftext10|> (see vocab.json in hf model) _IMAGE_PLACEHOLDER_TOKEN_ID = 200010 @@ -122,7 +121,7 @@ class Phi4MMImageEncoder(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, prefix: str = "", model_dir: str = "", ) -> None: @@ -467,8 +466,8 @@ class Phi4MMImagePixelInputs(TensorSchema): type: Literal["pixel_values"] - data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + pixel_values: Annotated[ + torch.Tensor | list[torch.Tensor], TensorShape( "bn", "p", 3, "h", "w", dynamic_dims={"p"} ), # may be different per batch and image @@ -499,8 +498,8 @@ class Phi4MMAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] - data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + audio_features: Annotated[ + torch.Tensor | list[torch.Tensor], TensorShape("bn", "t", 80, dynamic_dims={"t"}), ] @@ -521,7 +520,7 @@ class Phi4MMAudioEmbeddingInputs(TensorSchema): ] -Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs] +Phi4MMAudioInputs: TypeAlias = Phi4MMAudioFeatureInputs | Phi4MMAudioEmbeddingInputs def cat_with_pad(tensors, dim, padding_value=0): @@ -561,7 +560,7 @@ def audio_tokens(self) -> list[str]: def get_dynamic_hd( self, - processor: Optional[ProcessorMixin] = None, + processor: ProcessorMixin | None = None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -571,7 +570,7 @@ def get_dynamic_hd( def get_feature_extractor(self, **kwargs: object) -> SequenceFeatureExtractor: return self.get_hf_processor(**kwargs).audio_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": None, "image": None} def _find_target_aspect_ratio( @@ -709,7 +708,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[ProcessorMixin] = None, + processor: ProcessorMixin | None = None, ) -> int: hf_config = self.get_hf_config() vision_encoder_name = hf_config.img_processor @@ -735,7 +734,7 @@ def get_num_image_tokens( def get_image_size_with_most_features( self, - processor: Optional[ProcessorMixin] = None, + processor: ProcessorMixin | None = None, ) -> ImageSize: hf_config = self.get_hf_config() vision_encoder_name = hf_config.img_processor @@ -819,7 +818,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -986,6 +985,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): Implements the Phi-4-multimodal-instruct model in vLLM. 
""" + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -1008,7 +1009,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return f"<|image_{i}|>" if modality.startswith("audio"): @@ -1074,7 +1075,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[Phi4MMAudioInputs]: + ) -> Phi4MMAudioInputs | None: """ Parse and validate the audio input to the model. This handles both audio features and audio embeddings, but only the former is used for @@ -1094,7 +1095,8 @@ def _parse_and_validate_audio_input( if audio_features is not None: return Phi4MMAudioFeatureInputs( - type="audio_features", data=flatten_bn(audio_features) + type="audio_features", + audio_features=audio_features, ) if audio_embeds is not None: @@ -1119,7 +1121,7 @@ def _process_audio_input( if audio_input["type"] == "audio_embeds": return audio_input["data"] - audio_features = audio_input["data"] + audio_features = audio_input["audio_features"] # (e.g. multiple examples) and the second dim is the multi-audio dim # (e.g. multiple audios in the same example) @@ -1135,9 +1137,9 @@ def _process_audio_input( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Phi4MMImagePixelInputs]: - input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") - if input_image_embeds is None: + ) -> Phi4MMImagePixelInputs | None: + pixel_values = kwargs.get("input_image_embeds") + if pixel_values is None: return None image_sizes = kwargs.get("image_sizes") @@ -1149,52 +1151,9 @@ def _parse_and_validate_image_input( and num_img_tokens is not None ), "Missing image inputs" - if is_list_of(input_image_embeds, torch.Tensor): - assert all(p.dim() == 5 for p in input_image_embeds), ( - "Incorrect image inputs" - ) - # list len is batch_size. - # each tensor has dimension: num_img_per_example, num_hd_patches, - # channels, height, width. - # need to pad along num_hd_patches. - # mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w. - input_image_embeds = cat_with_pad(input_image_embeds, dim=0) - elif isinstance(input_image_embeds, torch.Tensor): - # dimension: batch_size, num_img_per_example, num_hd_patches, - # channels, height, width. - # we flatten first 2 dims to make it a single large batch for - # SigLIP Encoder. 
- assert input_image_embeds.dim() == 6, "Incorrect image inputs" - input_image_embeds = input_image_embeds.flatten(0, 1) - else: - raise ValueError("Incorrect input_image_embeds inputs") - - if isinstance(image_attention_mask, list): - image_attention_mask = cat_with_pad(image_attention_mask, dim=0) - elif isinstance(image_attention_mask, torch.Tensor): - image_attention_mask = image_attention_mask.flatten(0, 1) - else: - raise ValueError("Incorrect image_attention_mask inputs") - - if isinstance(image_sizes, list): - image_sizes = torch.cat(image_sizes, dim=0) - elif isinstance(image_sizes, torch.Tensor): - image_sizes = image_sizes.flatten(0, 1) - else: - raise ValueError("Incorrect image_sizes inputs") - - if isinstance(num_img_tokens, list): - num_img_tokens = [ - n for num_tensor in num_img_tokens for n in num_tensor.tolist() - ] - elif isinstance(num_img_tokens, torch.Tensor): - num_img_tokens = num_img_tokens.flatten(0, 1).tolist() - else: - raise ValueError("Incorrect num_img_tokens inputs") - return Phi4MMImagePixelInputs( type="pixel_values", - data=input_image_embeds, + pixel_values=pixel_values, image_sizes=image_sizes, image_attention_mask=image_attention_mask, num_img_tokens=num_img_tokens, @@ -1223,7 +1182,7 @@ def _process_image_input( self, image_input: Phi4MMImagePixelInputs ) -> list[torch.Tensor]: dtype = next(self.vision_encoder.parameters()).dtype - pixel_values = image_input["data"].to(dtype) + pixel_values = image_input["pixel_values"].to(dtype) image_sizes = image_input["image_sizes"] image_attention_mask = image_input["image_attention_mask"] image_embeds = self.vision_encoder( @@ -1248,8 +1207,8 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: if modality == "images": audio_projection_mode = "vision" image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += tuple(vision_embeddings) + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "audios": audio_input = modalities["audios"] audio_embeddings = self._process_audio_input( @@ -1263,8 +1222,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: @@ -1282,7 +1241,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index d289e26efa10..493fdb465fba 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -7,7 +7,7 @@ #!/usr/bin/env python3 import abc import math -from typing import Any, Literal, Optional, Union +from typing import Any, Literal import numpy as np import torch @@ -221,7 +221,7 @@ def forward( pos_k: torch.Tensor, pos_v: torch.Tensor, mask: torch.Tensor, - relative_attention_bias: Optional[Tensor] = None, + relative_attention_bias: Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ConformerEncoder forward. 
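# The typing edits in these hunks are mechanical: on Python 3.10+, `X | None`
# replaces Optional[X], `A | B` replaces Union[A, B], and TypeAlias marks a
# module-level alias such as Phi4MMAudioInputs explicitly. A minimal sketch of
# the equivalence (function names here are illustrative only):
from typing import Optional, TypeAlias, Union

import torch

MaybeBatched: TypeAlias = torch.Tensor | list[torch.Tensor]

def old_style(x: Optional[torch.Tensor], n: Union[int, float]) -> Optional[int]:
    return None

def new_style(x: torch.Tensor | None, n: int | float) -> int | None:
    return None

# Both spellings build equal annotation objects at runtime on 3.10+:
assert Optional[torch.Tensor] == (torch.Tensor | None)
assert Union[int, float] == (int | float)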
@@ -329,8 +329,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module): def __init__( self, input_size: int, - chunk_size: Union[int, list[int]], - left_chunk: Union[int, list[int]], + chunk_size: int | list[int], + left_chunk: int | list[int], attention_dim: int = 256, attention_heads: int = 4, input_layer: str = "nemo_conv", @@ -339,12 +339,12 @@ def __init__( time_reduction: int = 4, dropout_rate: float = 0.0, padding_idx: int = -1, - relative_attention_bias_args: Optional[dict[str, Any]] = None, + relative_attention_bias_args: dict[str, Any] | None = None, positional_dropout_rate: float = 0.0, - nemo_conv_settings: Optional[dict[str, Any]] = None, + nemo_conv_settings: dict[str, Any] | None = None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", attention_group_size: int = 1, - encoder_embedding_config: Optional[dict[str, Any]] = None, + encoder_embedding_config: dict[str, Any] | None = None, ) -> None: super().__init__() self.input_size = input_size @@ -411,8 +411,8 @@ def __init__( ) def compute_lens_change( - self, feature_lens: Union[int, torch.Tensor] - ) -> Union[int, torch.Tensor]: + self, feature_lens: int | torch.Tensor + ) -> int | torch.Tensor: """feature_lens: int return updated feature lens. @@ -452,8 +452,8 @@ def forward(self) -> Any: def _chunk_size_selection( self, - chunk_size: Optional[Union[int, list[int]]] = None, - left_chunk: Optional[Union[int, list[int]]] = None, + chunk_size: int | list[int] | None = None, + left_chunk: int | list[int] | None = None, ) -> tuple[int, int]: """If chunk size is a list, we will randomly select a chunk size.""" @@ -503,7 +503,7 @@ def _forward_embeddings_core( def _position_embedding( self, input_tensor: torch.Tensor - ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: pos_k = None pos_v = None if self.relative_attention_bias_layer is None: @@ -516,8 +516,8 @@ def _streaming_mask( self, seq_len: int, batch_size: int, - chunk_size: Union[int, list[int]], - left_chunk: Union[int, list[int]], + chunk_size: int | list[int], + left_chunk: int | list[int], ) -> torch.Tensor: chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( chunk_size, left_chunk @@ -540,25 +540,25 @@ def forward_embeddings( self, xs_pad: torch.Tensor, masks: torch.Tensor, - chunk_size_nc: Optional[Union[int, list[int]]] = None, - left_chunk_nc: Optional[Union[int, list[int]]] = None, - ) -> Union[ + chunk_size_nc: int | list[int] | None = None, + left_chunk_nc: int | list[int] | None = None, + ) -> ( tuple[ torch.Tensor, - Optional[torch.Tensor], - Optional[torch.Tensor], + torch.Tensor | None, + torch.Tensor | None, torch.Tensor, torch.Tensor, - ], - tuple[ + ] + | tuple[ torch.Tensor, - Optional[torch.Tensor], - Optional[torch.Tensor], + torch.Tensor | None, + torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor, - ], - ]: + ] + ): """Forwarding the inputs through the top embedding layers Args: @@ -803,9 +803,9 @@ class ConformerEncoder(TransformerEncoderBase): def __init__( # pylint: disable-all self, input_size: int, - chunk_size: Union[int, list[int]], - left_chunk: Union[int, list[int]], - num_lang: Optional[int] = None, + chunk_size: int | list[int], + left_chunk: int | list[int], + num_lang: int | None = None, attention_dim: int = 256, attention_heads: int = 4, linear_units: int = 2048, @@ -832,14 +832,14 @@ def __init__( # pylint: disable-all extra_layer_output_idx: int = -1, extra_multi_layer_output_idxs: list[int] = [], # noqa 
activation_checkpointing: str = "", - relative_attention_bias_args: Optional[dict[str, Any]] = None, + relative_attention_bias_args: dict[str, Any] | None = None, time_reduction: int = 4, use_pt_scaled_dot_product_attention: bool = False, - nemo_conv_settings: Optional[dict[str, Any]] = None, + nemo_conv_settings: dict[str, Any] | None = None, conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", replication_pad_for_subsample_embedding: bool = False, attention_group_size: int = 1, - encoder_embedding_config: Optional[dict[str, Any]] = None, + encoder_embedding_config: dict[str, Any] | None = None, ) -> None: super().__init__( input_size, @@ -908,12 +908,12 @@ def __init__( # pylint: disable-all def init_relative_attention_bias( self, input_tensor: torch.Tensor - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: if self.relative_attention_bias_layer: return self.relative_attention_bias_layer(input_tensor) def calculate_hs_mask( - self, xs_pad: torch.Tensor, device: torch.device, mask: Optional[torch.Tensor] + self, xs_pad: torch.Tensor, device: torch.device, mask: torch.Tensor | None ) -> torch.Tensor: max_audio_length = xs_pad.shape[1] batch_size = xs_pad.shape[0] @@ -1066,9 +1066,9 @@ def __init__( def forward( self, audio_embed: torch.Tensor, - mask: Optional[torch.Tensor], - embed_len: Optional[int] = None, - ) -> tuple[torch.Tensor, Optional[int]]: + mask: torch.Tensor | None, + embed_len: int | None = None, + ) -> tuple[torch.Tensor, int | None]: """forward decoder""" # audio_embed: N x T x D => N x D x T @@ -1224,7 +1224,7 @@ def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None: def get_audio_features( self, input_embeds: torch.Tensor, - audio_attention_mask: Optional[torch.Tensor] = None, + audio_attention_mask: torch.Tensor | None = None, audio_projection_mode: str = "speech", ) -> torch.Tensor: """ @@ -1278,7 +1278,7 @@ def get_audio_features( def forward( self, audio_features: torch.Tensor, - audio_attention_mask: Optional[torch.Tensor] = None, + audio_attention_mask: torch.Tensor | None = None, audio_projection_mode: str = "speech", ) -> torch.Tensor: """ diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index d50547c199ac..698435eb76c9 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -6,7 +6,6 @@ # but implemented by the Phi-Speech team #!/usr/bin/env python3 import math -from typing import Optional, Union import torch import torch.nn.functional as F @@ -917,7 +916,7 @@ def __init__( out_channels: int, kernel_size: int, stride: int = 1, - padding: Union[str, int] = 0, + padding: str | int = 0, dilation: int = 1, groups: int = 1, bias: bool = True, @@ -962,8 +961,8 @@ def __init__( ) def update_cache( - self, x: Tensor, cache: Optional[Tensor] = None - ) -> tuple[Tensor, Optional[Tensor]]: + self, x: Tensor, cache: Tensor | None = None + ) -> tuple[Tensor, Tensor | None]: if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) next_cache = cache @@ -978,8 +977,8 @@ def update_cache( return new_x, next_cache def forward( - self, x: Tensor, cache: Optional[Tensor] = None - ) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]: + self, x: Tensor, cache: Tensor | None = None + ) -> Tensor | tuple[Tensor, Tensor | None]: x, cache = self.update_cache(x, cache=cache) x = super().forward(x) if cache is None: @@ -1002,7 +1001,7 @@ def __init__( out_channels: int, kernel_size: int, stride: int = 1, - padding: 
Union[str, int] = 0, + padding: str | int = 0, dilation: int = 1, groups: int = 1, bias: bool = True, @@ -1371,9 +1370,7 @@ def get_sampling_frames(self) -> list[int]: def get_streaming_cache_size(self) -> list[int]: return [0, self.subsampling_factor + 1] - def forward( - self, x: Tensor, mask: Optional[Tensor] - ) -> tuple[Tensor, Optional[Tensor]]: + def forward(self, x: Tensor, mask: Tensor | None) -> tuple[Tensor, Tensor | None]: """ Forward method for NeMo subsampling. @@ -1615,10 +1612,10 @@ def set_export(self, mode: bool = True) -> None: def forward( self, x: Tensor, - memory: Optional[Tensor] = None, - pos_emb: Optional[Tensor] = None, - att_mask: Optional[Tensor] = None, - ) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + memory: Tensor | None = None, + pos_emb: Tensor | None = None, + att_mask: Tensor | None = None, + ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: """AttModule forward Args: @@ -1640,7 +1637,7 @@ def memory_dims(self, max_len: bool = False) -> tuple[int, int]: def masked_softmax( scores: Tensor, - mask: Optional[Tensor], + mask: Tensor | None, ) -> Tensor: if mask is not None: mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) @@ -1720,7 +1717,7 @@ def __init__( self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) - self.attn = torch.jit.Attribute(None, Optional[Tensor]) + self.attn = torch.jit.Attribute(None, Tensor | None) self.dropout = nn.Dropout(p=dropout_rate) self.dropout_rate = dropout_rate self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention @@ -1741,10 +1738,10 @@ def forward( query: Tensor, key: Tensor, value: Tensor, - pos_k: Optional[Tensor], - pos_v: Optional[Tensor], - mask: Optional[Tensor], - relative_attention_bias: Optional[Tensor] = None, + pos_k: Tensor | None, + pos_v: Tensor | None, + mask: Tensor | None, + relative_attention_bias: Tensor | None = None, ) -> Tensor: """Compute 'Scaled Dot Product Attention'. 
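# The phi4mm_utils hunk above also passes a PEP 604 union directly into
# torch.jit.Attribute (Tensor | None). That relies on `X | Y` being a real
# runtime object (types.UnionType) on Python 3.10+, not just an annotation.
# A small sketch of that behavior (no TorchScript involved here):
import types

import torch

maybe_tensor = torch.Tensor | None
assert isinstance(maybe_tensor, types.UnionType)

def describe(x: torch.Tensor | None) -> str:
    # isinstance() accepts union objects directly on 3.10+.
    if isinstance(x, torch.Tensor | None):
        return "tensor or None"
    return "something else"

assert describe(None) == "tensor or None"
assert describe(torch.zeros(1)) == "tensor or None"
assert describe("not a tensor") == "something else"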
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index fee52edfe26c..2cd4d8c72721 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -26,7 +26,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -257,9 +256,9 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + tp_size: int | None = None, prefix: str = "", ): super().__init__() @@ -304,12 +303,12 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, - head_dim: Optional[int] = None, + head_dim: int | None = None, max_position: int = 4096 * 32, rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[dict] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + rope_scaling: dict | None = None, prefix: str = "", ) -> None: super().__init__() @@ -386,8 +385,8 @@ class PhiMoEDecoderLayer(nn.Module): def __init__( self, config: PhiMoEConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -427,7 +426,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: residual = hidden_states @@ -496,9 +495,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -674,9 +673,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 65abebcf37de..0555717017cd 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -5,12 +5,13 @@ from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal import torch import torch.nn as nn import torch.nn.functional as F -from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage +from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from 
mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image @@ -99,7 +100,7 @@ class PixtralImagePixelInputs(TensorSchema): type: Literal["pixel_values"] = "pixel_values" images: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), ] @@ -143,9 +144,9 @@ def patch_size(self) -> int: def __call__( self, - text: Optional[Union[TextInput, list[TextInput]]] = None, - images: Optional[Union[ImageInput, list[ImageInput]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: TextInput | list[TextInput] | None = None, + images: ImageInput | list[ImageInput] | None = None, + return_tensors: str | TensorType | None = None, **kwargs, ) -> Mapping[str, NestedTensors]: if text is None: @@ -202,12 +203,12 @@ def get_tokenizer(self) -> MistralTokenizer: def get_hf_processor(self) -> PixtralProcessorAdapter: return PixtralProcessorAdapter(self.get_tokenizer()) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_vision_config( self, - processor: Optional[PixtralProcessorAdapter] = None, + processor: PixtralProcessorAdapter | None = None, ): if processor is None: processor = self.get_hf_processor() @@ -222,7 +223,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[PixtralProcessorAdapter] = None, + processor: PixtralProcessorAdapter | None = None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -248,7 +249,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) @@ -269,7 +270,7 @@ def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> ProcessorInputs: tokenizer = self.info.get_tokenizer() @@ -341,11 +342,11 @@ def get_replacement(item_idx: int): def _cached_apply_hf_processor( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, @@ -368,7 +369,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) merge_by_field_config = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None @@ -419,7 +420,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[PixtralImagePixelInputs]: + ) -> PixtralImagePixelInputs | None: images = kwargs.pop("images", None) if images is None: return None @@ -471,10 +472,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: 
Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for pixtral.""" if intermediate_tensors is not None: inputs_embeds = None @@ -488,7 +489,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -716,7 +717,7 @@ def forward( self, x: torch.Tensor, mask: torch.Tensor, - freqs_cis: Optional[torch.Tensor], + freqs_cis: torch.Tensor | None, ) -> torch.Tensor: for layer in self.layers: x = layer(x, mask=mask, freqs_cis=freqs_cis) @@ -758,7 +759,7 @@ def __init__(self, args: VisionEncoderArgs): head_dim = self.args.hidden_size // self.args.num_attention_heads assert head_dim % 2 == 0, "ROPE requires even head_dim" - self._freqs_cis: Optional[torch.Tensor] = None + self._freqs_cis: torch.Tensor | None = None @property def max_patches_per_side(self) -> int: @@ -1014,7 +1015,7 @@ class PixtralHFMLP(nn.Module): def __init__( self, config: PixtralVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: @@ -1048,7 +1049,7 @@ class PixtralHFAttention(nn.Module): def __init__( self, config: PixtralVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: @@ -1083,7 +1084,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: batch, patches, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) @@ -1118,7 +1119,7 @@ class PixtralHFTransformerBlock(nn.Module): def __init__( self, config: PixtralVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, prefix: str = "", ) -> None: @@ -1154,9 +1155,9 @@ class PixtralHFTransformer(nn.Module): def __init__( self, config: PixtralVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, prefix: str = "", ) -> None: super().__init__() @@ -1201,10 +1202,10 @@ class PixtralHFVisionModel(nn.Module): def __init__( self, config: PixtralVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", ) -> None: super().__init__() @@ -1246,8 +1247,8 @@ def forward( self, pixel_values: list[torch.Tensor], *, - select_layers: Optional[list[int]] = None, - feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + select_layers: list[int] | None = None, + feature_select_strategy: VisionFeatureSelectStrategy | None = None, ) -> tuple[torch.Tensor, ...]: """ Args: diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 278957e7cf6c..09293f63f70e 100644 --- 
a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from itertools import islice -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -64,7 +64,7 @@ ) from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata @@ -504,7 +504,7 @@ class DenseMLP(nn.Module): def __init__( self, config: Plamo2Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -672,7 +672,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, **kwargs, ): if residual is None: @@ -728,7 +728,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( @@ -770,8 +770,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -851,8 +851,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, ): hidden_states = self.model( @@ -901,7 +901,7 @@ def get_mamba_state_shape_from_config( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 6a12776b7f94..72e66d8f3038 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -10,7 +10,7 @@ import json from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -55,7 +55,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str = "silu", - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -84,9 +84,9 @@ def __init__( num_heads: int, max_position_embeddings: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + rope_scaling: dict[str, Any] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -144,8 +144,8 @@ class QWenBlock(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: 
Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -174,7 +174,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -226,9 +226,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -288,7 +288,7 @@ def __init__( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -357,9 +357,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.transformer( input_ids, positions, intermediate_tensors, inputs_embeds ) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index c8bc17dbfa0a..b26546647ce7 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -27,7 +27,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -77,7 +77,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -116,12 +116,12 @@ def __init__( num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + rope_scaling: tuple | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, - dual_chunk_attention_config: Optional[dict[str, Any]] = None, + dual_chunk_attention_config: dict[str, Any] | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -210,8 +210,8 @@ class Qwen2DecoderLayer(nn.Module): def __init__( self, config: Qwen2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -261,7 +261,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -362,9 +362,9 @@ def forward( self, input_ids: torch.Tensor, positions: 
torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -520,9 +520,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -531,7 +531,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 1ab2f43c9d73..a5d6004faf38 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -22,13 +22,14 @@ # limitations under the License. """Inference-only Qwen2.5-Omni model (thinker part).""" -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from copy import copy from functools import partial -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal import torch import torch.nn as nn +from transformers import PretrainedConfig from transformers.feature_extraction_utils import BatchFeature from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( Qwen2_5OmniConfig, @@ -45,7 +46,6 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionTransformer, @@ -93,6 +93,7 @@ from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -101,7 +102,9 @@ WeightsMapper, init_vllm_registered_model, maybe_prefix, + split_list_into_ranges, ) +from .vision import get_llm_pos_ids_for_vision try: import flash_attn @@ -122,7 +125,7 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] input_features: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("nmb", "tsl"), ] @@ -188,7 +191,7 @@ def __init__(self, spatial_merge_size: int, *args, **kwargs): def _parse_audio_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + data: dict[str, torch.Tensor] | ModalityData[ImageItem], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): return DictEmbeddingItems( @@ -222,7 +225,7 @@ def get_feature_extractor(self, **kwargs: object): assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return 
{"audio": None, "image": None, "video": None} @@ -250,7 +253,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -412,6 +415,59 @@ def _maybe_apply_prompt_updates( return prompt_ids, mm_placeholders + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: list[int] | torch.Tensor, + video_second_per_grid_t: float, + ) -> list[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = ( + torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second + ).long() + t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = ( + len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + ) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates + def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -491,7 +547,7 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int): else: video_second_per_grid_t = 1.0 - return MRotaryEmbedding.omni_get_updates_use_audio_in_video( + return self.omni_get_updates_use_audio_in_video( thinker_config=thinker_config, audio_len=audio_num_features, video_grid_thw=video_grid_thw, @@ -524,7 +580,7 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int): def _apply_hf_processor_main( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], @@ -609,7 +665,7 @@ def _validate_and_reshape_mm_tensor( def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[Qwen2_5OmniAudioFeatureInputs]: + ) -> Qwen2_5OmniAudioFeatureInputs | None: input_audio_features = kwargs.pop("input_audio_features", None) audio_feature_lengths = kwargs.pop("audio_feature_lengths", None) feature_attention_mask = kwargs.pop("feature_attention_mask", None) @@ -637,7 +693,7 @@ def _parse_and_validate_audio_input( def _parse_and_validate_image_input( self, **kwargs: dict[str, 
Any], - ) -> Optional[Qwen2_5_VLImageInputs]: + ) -> Qwen2_5_VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -687,7 +743,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: dict[str, Any], - ) -> Optional[Qwen2_5_VLVideoInputs]: + ) -> Qwen2_5_VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -808,6 +864,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( SupportsMultiModal, SupportsPP, SupportsLoRA, + SupportsMRoPE, Qwen2_5OmniConditionalGenerationMixin, ): hf_to_vllm_mapper = WeightsMapper( @@ -835,7 +892,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|IMAGE|><|vision_end|>" if modality.startswith("video"): @@ -929,6 +986,215 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def get_language_model(self) -> torch.nn.Module: return self.language_model + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float] | None = None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. 
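# A minimal sketch of the interleaving described in the docstring above: per
# temporal chunk, emit that chunk's video placeholder tokens, then up to
# t_ntoken_per_chunk audio placeholder tokens, and flush any remaining audio
# after the last chunk. The token strings and the chunking below are
# illustrative stand-ins, not the real Qwen2.5-Omni token ids or the
# split_list_into_ranges helper used in the diff.
AUDIO, VIDEO, A_BOS, A_EOS = "A", "V", "<a>", "</a>"

def interleave(video_tokens_per_chunk: list[int], audio_len: int,
               t_ntoken_per_chunk: int) -> list[str]:
    out = [A_BOS]
    remaining_audio = audio_len
    for n_video in video_tokens_per_chunk:
        out.extend([VIDEO] * n_video)
        n_audio = min(t_ntoken_per_chunk, remaining_audio)
        out.extend([AUDIO] * n_audio)
        remaining_audio -= n_audio
    out.extend([AUDIO] * remaining_audio)  # leftover audio after the last chunk
    out.append(A_EOS)
    return out

assert interleave([4, 4], audio_len=5, t_ntoken_per_chunk=2) == (
    [A_BOS] + ["V"] * 4 + ["A"] * 2 + ["V"] * 4 + ["A"] * 2 + ["A"] + [A_EOS]
)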
+ + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: + if use_audio_in_video and idx > 0: + if ( + src_item[idx] == vision_end_token_id + and src_item[idx - 1] == audio_end_token_id + ): + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif ( + src_item[idx] == audio_start_token_id + and src_item[idx - 1] == vision_start_token_id + ): + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2 + ) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) + * second_per_grid_ts[video_idx] + * tokens_per_second + ).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + grid_t = 
video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = ( + torch.arange(grid_t) + * second_per_grid_ts[video_idx] + * tokens_per_second + ).long() + t_index_split_chunk = split_list_into_ranges( + t_index, t_ntoken_per_chunk + ) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = ( + len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + ) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) + vision_llm_pos_ids_list = get_llm_pos_ids_for_vision( + start_idx, + video_idx, + spatial_merge_size, + t_chunk, + grid_hs, + grid_ws, + ).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + * [audio_token_id] + ) + audio_start_idx = ( + start_idx + if len(audio_llm_pos_ids_list) == 0 + else audio_llm_pos_ids_list[-1][0].item() + 1 + ) + if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = ( + torch.arange( + min( + t_ntoken_per_chunk, pure_audio_len - added_audio_len + ) + ).expand(3, -1) + + audio_start_idx + ).split(1, dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min( + t_ntoken_per_chunk, pure_audio_len - added_audio_len + ) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id] + ) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand(3, -1) + + llm_pos_ids_list[-1].max() + + 1 + ).split(1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = ( + torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) + ) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: @@ -943,14 +1209,14 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - vision_embeddings = self._process_image_input(multimodal_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) if modality == "audio": audio_embeddings = self._process_audio_input(multimodal_input) - multimodal_embeddings += audio_embeddings + multimodal_embeddings += tuple(audio_embeddings) return multimodal_embeddings # TODO (ywang96): support overlapping modality embeddings so that @@ -958,9 +1224,9 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def get_input_embeddings( self, input_ids: 
torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: # This is to satisfy the type checker for each overload @@ -974,7 +1240,7 @@ def get_input_embeddings( handle_oov_mm_token=handle_oov_mm_token, ) - def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTensors]: + def get_multimodal_embeddings_v0(self, **kwargs: object) -> NestedTensors | None: audio_input = self._parse_and_validate_audio_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs) @@ -999,10 +1265,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -1014,7 +1280,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 9cd83f61d921..c657b06d4355 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -26,15 +26,16 @@ # limitations under the License. 
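# In the Omni thinker hunk above, per-modality embeddings are wrapped in
# tuple(...) before being accumulated. multimodal_embeddings is a tuple, and
# concatenating a tuple with a list (or generator) raises TypeError, so each
# modality's embeddings are materialized as a tuple first; the exact return
# type of the _process_*_input helpers is not shown here, so treat this as an
# illustration of the Python behavior rather than of the vLLM data flow:
embeddings: tuple[str, ...] = ()
new_items = ["img_emb_0", "img_emb_1"]  # e.g. one embedding per image

try:
    embeddings += new_items  # type: ignore[operator]
except TypeError:
    pass  # tuple += list is not allowed

embeddings += tuple(new_items)  # this is what the updated code does
assert embeddings == ("img_emb_0", "img_emb_1")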
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping, Sequence +import math +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import lru_cache, partial -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, @@ -56,6 +57,7 @@ ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -72,13 +74,14 @@ from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( MultiModalEmbeddings, SupportsEagle3, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsMultiModalPruning, SupportsPP, @@ -97,7 +100,11 @@ init_vllm_registered_model, maybe_prefix, ) -from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from .vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -115,7 +122,7 @@ class Qwen2_5_VLImagePixelInputs(TensorSchema): - pixel_values shape: (num_patches, num_channels * patch_size * patch_size) - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) - formatnum_channels * patch_size * patch_size + format. 
""" type: Literal["pixel_values"] @@ -160,9 +167,9 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): ] -Qwen2_5_VLImageInputs = Union[ - Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs -] +Qwen2_5_VLImageInputs: TypeAlias = ( + Qwen2_5_VLImagePixelInputs | Qwen2_5_VLImageEmbeddingInputs +) class Qwen2_5_VLVideoPixelInputs(TensorSchema): @@ -196,7 +203,7 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): ] second_per_grid_ts: Annotated[ - Optional[torch.Tensor], + torch.Tensor | None, TensorShape("nv"), ] @@ -230,9 +237,9 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): ] -Qwen2_5_VLVideoInputs = Union[ - Qwen2_5_VLVideoPixelInputs, Qwen2_5_VLVideoEmbeddingInputs -] +Qwen2_5_VLVideoInputs: TypeAlias = ( + Qwen2_5_VLVideoPixelInputs | Qwen2_5_VLVideoEmbeddingInputs +) # === Vision Encoder === # @@ -244,7 +251,7 @@ def __init__( hidden_features: int, bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -300,7 +307,7 @@ def __init__( embed_dim: int, num_heads: int, projection_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, attn_backend: _Backend = _Backend.TORCH_SDPA, @@ -385,8 +392,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -395,7 +402,7 @@ def forward( q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... 
-> b s ...") for x in (q, k, v)) if rotary_pos_emb is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) @@ -465,8 +472,8 @@ def __init__( num_heads: int, mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, attn_backend: _Backend = _Backend.TORCH_SDPA, @@ -502,8 +509,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), @@ -531,18 +538,15 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( - in_channels, + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), hidden_size, - kernel_size=kernel_size, - stride=kernel_size, bias=False, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x).view(L, self.hidden_size) + x = self.proj(x) return x @@ -551,9 +555,9 @@ def __init__( self, d_model: int, context_dim: int, - norm_layer: Optional[Callable[[int], nn.Module]] = None, + norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -633,9 +637,10 @@ def __init__( self, vision_config: Qwen2_5_VLVisionConfig, norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() @@ -668,7 +673,9 @@ def __init__( use_upstream_fa = False self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype() + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) if ( self.attn_backend != _Backend.FLASH_ATTN @@ -814,7 +821,7 @@ def get_rope_by_thw(self, t, h, w): def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[Optional[int], Optional[list[int]]]: + ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( self.attn_backend == _Backend.FLASH_ATTN @@ -946,6 +953,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -1053,6 +1063,7 @@ class Qwen2_5_VLForConditionalGeneration( SupportsQuant, SupportsEagle3, SupportsMultiModalPruning, + SupportsMRoPE, ): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -1073,8 +1084,133 @@ class Qwen2_5_VLForConditionalGeneration( supports_encoder_tp_data = True + 
def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_videos > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + @classmethod - def get_placeholder_str(cls, modality: 
str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): @@ -1098,12 +1234,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if multimodal_config.get_limit_per_prompt( "image" ) or multimodal_config.get_limit_per_prompt("video"): + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.visual = Qwen2_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=self.quant_config, prefix=maybe_prefix(prefix, "visual"), use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, ) else: self.visual = None @@ -1145,7 +1287,7 @@ def _validate_and_reshape_mm_tensor( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Qwen2_5_VLImageInputs]: + ) -> Qwen2_5_VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1183,7 +1325,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[Qwen2_5_VLVideoInputs]: + ) -> Qwen2_5_VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1458,29 +1600,29 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - vision_embeddings = self._process_image_input(multimodal_input) + image_embeddings = self._process_image_input(multimodal_input) if self.is_multimodal_pruning_enabled: - vision_embeddings = self._postprocess_image_embeds_evs( - vision_embeddings, multimodal_input + image_embeddings = self._postprocess_image_embeds_evs( + image_embeddings, multimodal_input ) - multimodal_embeddings += vision_embeddings + multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) if self.is_multimodal_pruning_enabled: video_embeddings = self._postprocess_video_embeds_evs( video_embeddings, multimodal_input ) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Qwen2.5-VL. 
Args: @@ -1506,7 +1648,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index e61a730f97bb..553fdc4a9e17 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -24,7 +24,7 @@ """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import torch import torch.nn as nn @@ -78,7 +78,7 @@ class Qwen2AudioFeatureInputs(TensorSchema): type: Literal["audio_features"] input_features: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("na", "nmb", 3000), ] @@ -105,7 +105,7 @@ class Qwen2AudioEmbeddingInputs(TensorSchema): ] -Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs] +Qwen2AudioInputs: TypeAlias = Qwen2AudioFeatureInputs | Qwen2AudioEmbeddingInputs # === Audio Encoder === # @@ -140,7 +140,7 @@ def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": None} @@ -157,7 +157,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: feature_extractor = self.info.get_feature_extractor() @@ -185,8 +185,8 @@ def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]): class Qwen2AudioMultiModalDataParser(MultiModalDataParser): def _parse_audio_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + data: dict[str, torch.Tensor] | ModalityData[AudioItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -314,7 +314,7 @@ def get_replacement_qwen2_audio(item_idx: int): ) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("audio"): return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>" @@ -358,7 +358,7 @@ def _validate_and_reshape_mm_tensor( def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[Qwen2AudioInputs]: + ) -> Qwen2AudioInputs | None: input_features = kwargs.pop("input_features", None) audio_embeds = kwargs.pop("audio_embeds", None) feature_attention_mask = kwargs.pop("feature_attention_mask", None) @@ -395,7 +395,7 @@ def _parse_and_validate_audio_input( def _process_audio_input( self, audio_input: Qwen2AudioInputs - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: if audio_input["type"] == "audio_embeds": audio_embeds = audio_input["audio_embeds"] return tuple(audio_embeds) @@ -471,10 +471,10 @@ def forward( self, input_ids: 
torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -486,7 +486,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 61b203a08349..c03bd6a3c6d7 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -27,7 +27,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any import torch import torch.nn.functional as F @@ -40,7 +40,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -77,8 +77,9 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, + expert_gate: torch.nn.Linear | None = None, prefix: str = "", ) -> None: super().__init__() @@ -102,19 +103,24 @@ def __init__( f"Unsupported activation: {hidden_act}. Only silu is supported for now." ) self.act_fn = SiluAndMul() + self.expert_gate = expert_gate def forward(self, x): gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x + out = self.act_fn(gate_up) + out, _ = self.down_proj(out) + + if self.expert_gate is not None: + out = F.sigmoid(self.expert_gate(x)) * out + + return out class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( self, config: Qwen2MoeConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -126,17 +132,6 @@ def __init__( f"the number of experts {config.num_experts}." 
) - self.experts = FusedMoE( - num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - ) - self.gate = ReplicatedLinear( config.hidden_size, config.num_experts, @@ -144,39 +139,47 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate", ) + + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + reduce_results=False, + expert_gate=self.shared_expert_gate, prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - shared_output = None - if self.shared_expert is not None: - shared_output = self.shared_expert(hidden_states) - if self.shared_expert_gate is not None: - shared_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output - ) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + if self.shared_expert is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states @@ -192,12 +195,12 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", - dual_chunk_attention_config: Optional[dict[str, Any]] = None, + dual_chunk_attention_config: dict[str, Any] | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -282,8 +285,8 @@ class Qwen2MoeDecoderLayer(nn.Module): def __init__( self, config: Qwen2MoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -336,7 +339,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> torch.Tensor: # Self 
Attention if residual is None: @@ -393,9 +396,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -418,7 +421,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -531,11 +534,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): "q_proj", "k_proj", "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + ] } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -544,6 +543,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + # Only perform the following mapping when Qwen2MoeMLP exists + if ( + getattr(config, "mlp_only_layers", []) + or config.shared_expert_intermediate_size > 0 + ): + self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] + self.model = Qwen2MoeModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) @@ -567,9 +573,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -578,7 +584,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 75ed95477f78..e2ba0e262cf7 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -8,7 +8,6 @@ """Inference-only Qwen2-RM model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Optional, Union import torch from torch import nn @@ -83,9 +82,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -108,7 +107,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, + {"token_classify": Pooler.for_token_classify(pooler_config)} ) @@ -121,4 
+120,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)}) + self.pooler = DispatchPooler( + {"token_classify": Pooler.for_token_classify(pooler_config)} + ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index cb1bf3825c74..61f7970d56f6 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -25,9 +25,10 @@ # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping, Sequence +import math +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Callable, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import torch import torch.nn as nn @@ -53,7 +54,11 @@ from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding.common import ( dispatch_rotary_emb_function, @@ -100,7 +105,11 @@ init_vllm_registered_model, maybe_prefix, ) -from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from .vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -167,7 +176,7 @@ class Qwen2VLImageEmbeddingInputs(TensorSchema): ] -Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] +Qwen2VLImageInputs: TypeAlias = Qwen2VLImagePixelInputs | Qwen2VLImageEmbeddingInputs class Qwen2VLVideoPixelInputs(TensorSchema): @@ -228,7 +237,7 @@ class Qwen2VLVideoEmbeddingInputs(TensorSchema): ] -Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, Qwen2VLVideoEmbeddingInputs] +Qwen2VLVideoInputs: TypeAlias = Qwen2VLVideoPixelInputs | Qwen2VLVideoEmbeddingInputs # === Vision Encoder === # @@ -239,7 +248,7 @@ def __init__( in_features: int, hidden_features: int, act_layer: type[nn.Module] = QuickGELU, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -317,9 +326,10 @@ def __init__( embed_dim: int, num_heads: int, projection_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. 
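Stepping back to the Qwen2MoeSparseMoeBlock change earlier in this diff: the sigmoid shared-expert gate is now applied inside Qwen2MoeMLP via the new expert_gate argument, and SharedFusedMoE returns the (shared, routed) pair so the block only has to sum the two tensors. A toy sketch of the gating itself, using plain nn.Linear stand-ins and assuming nothing about the fused kernels:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size = 4
x = torch.randn(2, hidden_size)

shared_mlp = torch.nn.Linear(hidden_size, hidden_size, bias=False)   # stand-in for the shared expert
expert_gate = torch.nn.Linear(hidden_size, 1, bias=False)            # plays the role of shared_expert_gate

# Old layout: the sparse-MoE block gated the shared output after calling the expert.
old_out = F.sigmoid(expert_gate(x)) * shared_mlp(x)

# New layout: the expert gates its own output on the *input* x (see Qwen2MoeMLP.forward above).
def gated_expert(x: torch.Tensor) -> torch.Tensor:
    out = shared_mlp(x)
    return F.sigmoid(expert_gate(x)) * out

assert torch.allclose(old_out, gated_expert(x))

With the gate folded into the expert, the block's forward reduces to final_hidden_states[0] + final_hidden_states[1] when a shared expert is present.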
@@ -355,6 +365,7 @@ def __init__( self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) self.use_upstream_fa = False @@ -362,6 +373,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) @@ -413,8 +425,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, 3 * head * head_dim] x, _ = self.qkv(x) @@ -423,7 +435,7 @@ def forward( q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v)) if rotary_pos_emb is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) @@ -493,10 +505,11 @@ def __init__( num_heads: int, mlp_ratio: float, act_layer: type[nn.Module] = QuickGELU, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -512,6 +525,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) self.mlp = Qwen2VisionMLP( dim, @@ -527,8 +541,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), @@ -556,18 +570,15 @@ def __init__( self.embed_dim = embed_dim kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( - in_channels, + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), embed_dim, - kernel_size=kernel_size, - stride=kernel_size, bias=False, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x).view(L, self.embed_dim) + x = self.proj(x) return x @@ -576,9 +587,9 @@ def __init__( self, d_model: int, context_dim: int, - norm_layer: Optional[Callable[[int], nn.Module]] = None, + norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -659,9 +670,10 @@ def __init__( self, vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() @@ -703,6 
+715,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.blocks.{layer_idx}", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) ] @@ -716,7 +729,9 @@ def __init__( use_data_parallel=use_data_parallel, ) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype() + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( torch.get_default_dtype() @@ -766,7 +781,7 @@ def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor - ) -> tuple[Optional[int], Optional[list[int]]]: + ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( self.attn_backend == _Backend.FLASH_ATTN @@ -826,6 +841,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -889,8 +907,8 @@ def __init__(self, spatial_merge_size: int, *args, **kwargs): def _parse_image_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + data: dict[str, torch.Tensor] | ModalityData[ImageItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -903,8 +921,8 @@ def _parse_image_data( def _parse_video_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + data: dict[str, torch.Tensor] | ModalityData[VideoItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -930,7 +948,7 @@ def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor: def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor: return self.get_hf_processor(**kwargs).image_processor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": None} def get_mm_max_tokens_per_item( @@ -949,7 +967,7 @@ def _get_vision_info( image_height: int, num_frames: int = 1, do_resize: bool = True, - image_processor: Optional[Qwen2VLImageProcessor], + image_processor: Qwen2VLImageProcessor | None, ) -> tuple[ImageSize, int]: if image_processor is None: image_processor = self.get_image_processor() @@ -990,7 +1008,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - image_processor: Optional[Qwen2VLImageProcessor], + image_processor: Qwen2VLImageProcessor | None, ) -> int: _, num_image_tokens = self._get_vision_info( image_width=image_width, @@ -1006,7 +1024,7 @@ def get_num_video_tokens( image_width: int, image_height: int, num_frames: int, - image_processor: Optional[Qwen2VLImageProcessor], + image_processor: Qwen2VLImageProcessor | None, ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, @@ -1100,7 +1118,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = 
mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1207,12 +1225,12 @@ def get_mrope_input_positions( self, input_tokens: list[int], hf_config: PretrainedConfig, - image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], - second_per_grid_ts: Optional[list[float]] = None, + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None = None, context_len: int = 0, - seq_len: Optional[int] = None, - audio_feature_lengths: Optional[torch.Tensor] = None, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: """Get M-RoPE input positions for Qwen2-VL model.""" @@ -1335,7 +1353,7 @@ def get_mrope_input_positions( return llm_positions, mrope_position_delta @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): @@ -1356,12 +1374,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if multimodal_config.get_limit_per_prompt( "image" ) or multimodal_config.get_limit_per_prompt("video"): + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, ) else: self.visual = None @@ -1396,7 +1420,7 @@ def _validate_and_reshape_mm_tensor( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Qwen2VLImageInputs]: + ) -> Qwen2VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1434,7 +1458,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[Qwen2VLVideoInputs]: + ) -> Qwen2VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1561,12 +1585,12 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: for modality in modalities: if modality == "images": image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings @@ -1574,10 +1598,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | 
IntermediateTensors: """Run forward pass for Qwen2-VL. Args: @@ -1606,7 +1630,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1634,7 +1658,7 @@ class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor): class Tarsier2ImageProcessor(Qwen2VLImageProcessor): def __init__( self, - size: Optional[dict[str, int]] = None, + size: dict[str, int] | None = None, **kwargs, ) -> None: if size is not None and "min_pixels" in size and "max_pixels" in size: diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index bcd4968ba5c4..563d3cc23d72 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -24,7 +24,7 @@ """Inference-only Qwen3 model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -58,16 +58,16 @@ def __init__( num_heads: int, num_kv_heads: int, max_position: int = 4096 * 32, - head_dim: Optional[int] = None, + head_dim: int | None = None, rms_norm_eps: float = 1e-06, qkv_bias: bool = False, rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + rope_scaling: tuple | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, - dual_chunk_attention_config: Optional[dict[str, Any]] = None, + dual_chunk_attention_config: dict[str, Any] | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -160,8 +160,8 @@ class Qwen3DecoderLayer(nn.Module): def __init__( self, config: Qwen3Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -214,7 +214,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -315,9 +315,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -326,7 +326,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 34b5af846493..8452d7b04f5c 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -26,7 +26,7 @@ import typing from collections.abc import Callable, Iterable from itertools import islice -from typing import Any, Optional, Union +from typing import Any 
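Both the Qwen2-VL hunk above and the Qwen2.5-VL get_mrope_input_positions method added earlier build the same three-axis position layout: text tokens advance all three axes together, while vision tokens get a (t, h, w) grid offset by the preceding text length. A small worked example with a single 4x4 image patch grid and spatial_merge_size=2 (values chosen for illustration; videos and the tokens_per_second scaling are omitted):

import torch

spatial_merge_size = 2
t, h, w = 1, 4, 4                       # image_grid_thw for one image
llm_t, llm_h, llm_w = t, h // spatial_merge_size, w // spatial_merge_size
text_len = 3                            # three text tokens precede the image

# Text tokens: the same position on every axis (0, 1, 2).
text_pos = torch.arange(text_len).view(1, -1).expand(3, -1)

# Vision tokens: a (t, h, w) grid, shifted past the text positions.
t_idx = torch.arange(llm_t).view(-1, 1).expand(-1, llm_h * llm_w).flatten()
h_idx = torch.arange(llm_h).view(1, -1, 1).expand(llm_t, -1, llm_w).flatten()
w_idx = torch.arange(llm_w).view(1, 1, -1).expand(llm_t, llm_h, -1).flatten()
vision_pos = torch.stack([t_idx, h_idx, w_idx]) + text_len

positions = torch.cat([text_pos, vision_pos], dim=1)
# positions[0] (t axis): 0 1 2 3 3 3 3
# positions[1] (h axis): 0 1 2 3 3 4 4
# positions[2] (w axis): 0 1 2 3 4 3 4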
import torch from torch import nn @@ -64,7 +64,7 @@ from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors -from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, @@ -84,7 +84,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, reduce_results: bool = True, prefix: str = "", ) -> None: @@ -215,15 +215,15 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - head_dim: Optional[int] = None, + head_dim: int | None = None, rms_norm_eps: float = 1e-06, qkv_bias: bool = False, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", - dual_chunk_attention_config: Optional[dict[str, Any]] = None, + dual_chunk_attention_config: dict[str, Any] | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -374,7 +374,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -422,6 +422,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) + # Track layers for auxiliary hidden state outputs (EAGLE3) + self.aux_hidden_state_layers: tuple[int, ...] 
= () def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -430,9 +432,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -443,13 +445,29 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): + + aux_hidden_states = [] + for layer_idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer), + start=self.start_layer, + ): + # Collect auxiliary hidden states if specified + if layer_idx in self.aux_hidden_state_layers: + aux_hidden_state = ( + hidden_states + residual if residual is not None else hidden_states + ) + aux_hidden_states.append(aux_hidden_state) hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) + + # Return auxiliary hidden states if collected + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -606,17 +624,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params -class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts): +class Qwen3MoeForCausalLM( + nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts +): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + ] } fall_back_to_pt_during_load = False @@ -627,6 +643,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + # Only perform the following mapping when Qwen3MoeMLP exists + if getattr(config, "mlp_only_layers", []): + self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"] self.model = Qwen3MoeModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) @@ -702,6 +721,13 @@ def update_physical_experts_metadata( moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -709,9 +735,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None 
= None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -720,7 +746,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index cea3faf45a14..e81ad5f68d8f 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -4,10 +4,8 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional import torch -import torch.nn.functional as F from einops import rearrange from torch import nn from transformers.activations import ACT2FN @@ -36,7 +34,7 @@ chunk_gated_delta_rule, fused_recurrent_gated_delta_rule, ) -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -73,7 +71,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from .interfaces import ( @@ -136,20 +134,6 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): self.physical_expert_start + self.n_local_physical_experts ) - self.experts = FusedMoE( - num_experts=self.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - ) - self.gate = ReplicatedLinear( config.hidden_size, config.num_experts, @@ -158,18 +142,35 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.gate", ) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen3NextMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs(), + reduce_results=False, + expert_gate=self.shared_expert_gate, prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_expert, + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
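The aux_hidden_state_layers loop added to Qwen3MoeModel above records hidden states just before selected layers run, which is what the EAGLE3 drafter consumes, with the default selection rule from get_eagle3_aux_hidden_state_layers: (2, num_layers // 2, num_layers - 3). A minimal sketch with dummy linear layers; the real loop snapshots hidden_states + residual rather than the raw hidden state:

import torch
import torch.nn as nn

num_layers = 8
aux_layers = (2, num_layers // 2, num_layers - 3)        # -> (2, 4, 5)

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(num_layers)])
hidden = torch.randn(1, 4)

aux_hidden_states = []
for layer_idx, layer in enumerate(layers):
    if layer_idx in aux_layers:
        aux_hidden_states.append(hidden)                  # snapshot before the layer runs
    hidden = layer(hidden)

# Three snapshots are returned alongside the final hidden state.
assert len(aux_hidden_states) == 3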
@@ -180,22 +181,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) - shared_output = None - if self.shared_expert is not None: - shared_output = self.shared_expert(hidden_states) - if self.shared_expert_gate is not None: - shared_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output - ) - # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + if self.shared_expert is not None: + final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( @@ -239,10 +232,10 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: def __init__( self, config: Qwen3NextConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + speculative_config: SpeculativeConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -332,7 +325,6 @@ def __init__( self.A_log = nn.Parameter( torch.empty( divide(self.num_v_heads, self.tp_size), - dtype=torch.float32, ) ) @@ -345,7 +337,7 @@ def __init__( group_size=None, norm_before_gate=True, device=current_platform.current_device(), - dtype=config.torch_dtype, + dtype=config.dtype, ) self.out_proj = RowParallelLinear( @@ -430,7 +422,7 @@ def rearrange_mixed_qkv(self, mixed_qkv): (query, key), ) value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) - return query, key, value + return query.contiguous(), key.contiguous(), value.contiguous() def forward( self, @@ -462,7 +454,8 @@ def _forward( spec_query_start_loc = attn_metadata.spec_query_start_loc non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_masks = attn_metadata.spec_token_masks + spec_token_indx = attn_metadata.spec_token_indx + non_spec_token_indx = attn_metadata.non_spec_token_indx spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 self_kv_cache = self.kv_cache[forward_context.virtual_engine] @@ -470,8 +463,6 @@ def _forward( ssm_state = self_kv_cache[1] num_actual_tokens = attn_metadata.num_actual_tokens num_accepted_tokens = attn_metadata.num_accepted_tokens - if spec_token_masks is not None: - spec_token_masks = spec_token_masks[:num_actual_tokens] # 1. 
Set up dimensions for reshapes later projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens]) @@ -494,8 +485,8 @@ def _forward( mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: - mixed_qkv_spec = mixed_qkv[spec_token_masks] - mixed_qkv_non_spec = mixed_qkv[~spec_token_masks] + mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) + mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) else: mixed_qkv_spec = None mixed_qkv_non_spec = mixed_qkv @@ -565,10 +556,10 @@ def _forward( g_non_spec = None beta_non_spec = None else: - g_spec = g[:, spec_token_masks] - beta_spec = beta[:, spec_token_masks] - g_non_spec = g[:, ~spec_token_masks] - beta_non_spec = beta[:, ~spec_token_masks] + g_spec = g.index_select(1, spec_token_indx) + beta_spec = beta.index_select(1, spec_token_indx) + g_non_spec = g.index_select(1, non_spec_token_indx) + beta_non_spec = beta.index_select(1, non_spec_token_indx) else: g_spec = None beta_spec = None @@ -645,8 +636,9 @@ def _forward( dtype=core_attn_out_non_spec.dtype, device=core_attn_out_non_spec.device, ) - core_attn_out[:, spec_token_masks] = core_attn_out_spec - core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec + core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec) + core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) + elif spec_sequence_masks is not None: core_attn_out = core_attn_out_spec else: @@ -667,9 +659,9 @@ class Qwen3NextAttention(nn.Module): def __init__( self, config: Qwen3NextConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -854,7 +846,7 @@ def __init__( 1, 1, config.hidden_size, - dtype=config.torch_dtype, + dtype=config.dtype, ), ) self.ffn_layer_scale = torch.nn.Parameter( @@ -862,14 +854,14 @@ def __init__( 1, 1, config.hidden_size, - dtype=config.torch_dtype, + dtype=config.dtype, ), ) def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, positions: torch.Tensor = None, **kwargs: object, ): @@ -977,8 +969,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -1008,7 +1000,7 @@ def forward( def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - return FusedMoE.make_expert_params_mapping( + return SharedFusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -1150,7 +1142,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Set MoE hyperparameters self.expert_weights = [] - self.moe_layers: list[FusedMoE] = [] + self.moe_layers: list[SharedFusedMoE] = [] example_layer = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -1213,8 +1205,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: 
Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): hidden_states = self.model( @@ -1257,7 +1249,7 @@ def get_mamba_state_shape_from_config( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index 828931716c8f..a447484ae82a 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -3,7 +3,6 @@ """Inference-only Qwen3Next MTP model.""" from collections.abc import Iterable -from typing import Optional import torch from torch import nn @@ -108,8 +107,8 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: if get_pp_group().is_first_rank: @@ -275,8 +274,8 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): hidden_states = self.model( @@ -288,7 +287,7 @@ def compute_logits( self, hidden_states: torch.Tensor, spec_step_idx: int = 0, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py new file mode 100755 index 000000000000..89ce0068fb1a --- /dev/null +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -0,0 +1,1743 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
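Before the new qwen3_omni_moe_thinker.py file begins, one note on the qwen3_next.py hunk above: boolean spec_token_masks indexing is replaced with precomputed spec_token_indx / non_spec_token_indx tensors plus index_copy_. The two forms select and reassemble the same rows; a small sketch on dim 0 (the real code also splits g and beta along dim 1, and the index tensors come from the attention metadata):

import torch

x = torch.arange(12.0).view(6, 2)
spec_mask = torch.tensor([True, False, True, False, False, True])

spec_idx = spec_mask.nonzero(as_tuple=True)[0]         # plays the role of spec_token_indx
non_spec_idx = (~spec_mask).nonzero(as_tuple=True)[0]  # plays the role of non_spec_token_indx

# Selection: boolean masking and index_select pick identical rows.
assert torch.equal(x[spec_mask], x.index_select(0, spec_idx))
assert torch.equal(x[~spec_mask], x.index_select(0, non_spec_idx))

# Reassembly: index_copy_ reproduces the masked-assignment result.
out = torch.empty_like(x)
out.index_copy_(0, spec_idx, x.index_select(0, spec_idx))
out.index_copy_(0, non_spec_idx, x.index_select(0, non_spec_idx))
assert torch.equal(out, x)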
+"""Inference-only Qwen3-Omni-Moe model (thinker part).""" + +import math +from collections.abc import Callable, Iterable, Mapping, Sequence +from functools import partial +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging.version import Version +from transformers import PretrainedConfig +from transformers import __version__ as TRANSFORMERS_VERSION +from transformers.feature_extraction_utils import BatchFeature +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeConfig, + Qwen3OmniMoeThinkerConfig, +) +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoder, +) +from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( + Qwen3OmniMoeProcessor, +) +from transformers.models.whisper import WhisperFeatureExtractor + +from vllm.attention.backends.registry import _Backend +from vllm.attention.layer import check_upstream_fa_availability +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen2_audio import ( + Qwen2AudioFeatureInputs, + Qwen2AudioProcessingInfo, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargsItems +from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, +) +from vllm.sequence import IntermediateTensors + +from .interfaces import ( + MultiModalEmbeddings, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from .qwen2_5_omni_thinker import ( + Qwen2_5OmniConditionalGenerationMixin, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo, +) +from .qwen2_5_vl import ( + Qwen2_5_VisionAttention, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLProcessingInfo, +) +from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) +from .vision import ( + conv3d_to_linear_weight, + get_llm_pos_ids_for_vision, + get_vit_attn_backend, +) + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None + +logger = init_logger(__name__) + + +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ( + ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + ) + return feat_lengths, output_lengths + + +class Qwen3_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + 
self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), + hidden_size, + bias=True, + return_bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = self.proj(x) + return x + + +class Qwen3_VisionMLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.linear_fc1 = ColumnParallelLinear( + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1", + ) + self.linear_fc2 = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2", + ) + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return mlp_output + + +class Qwen3_VisionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.mlp = Qwen3_VisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen3_VisionPatchMerger(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Callable[[int], nn.Module] | None = None, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + + self.use_postshuffle_norm = use_postshuffle_norm + if self.use_postshuffle_norm: + context_dim = self.hidden_size + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = norm_layer( + self.hidden_size if use_postshuffle_norm else context_dim + ) + self.mlp = nn.ModuleList( + [ + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_postshuffle_norm: + x = 
self.ln_q(x.view(-1, self.hidden_size)) + else: + x = self.ln_q(x).view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen3Omni_VisionTransformer(nn.Module): + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + attn_backend_override: _Backend | None = None, + ) -> None: + super().__init__() + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.image_size = vision_config.image_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.num_grid_per_side = self.image_size // self.patch_size + self.apply_vit_abs_pos_embed = vision_config.apply_vit_abs_pos_embed + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + # vit pos embeding, TODO: spatial_patch_size vs patch_size + if self.apply_vit_abs_pos_embed: + self.pos_embed = nn.Embedding(self.num_grid_per_side**2, self.hidden_size) + else: + self.pos_embed = nn.Parameter( + torch.empty([1, self.num_grid_per_side**2, self.hidden_size]) + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(vision_config.depth) + ] + ) + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + if self.deepstack_visual_indexes is not None: + self.merger_list = nn.ModuleList( + [ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.merger_list.{layer_idx}", + ) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ] + ) + + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): + self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + 
self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + + outputs = [] + for t, h, w in grid_thw: + h_idxs = torch.linspace( + 0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device + ) + w_idxs = torch.linspace( + 0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device + ) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij") + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") + h_floor_grid_idx = h_floor_grid * num_grid_per_side + h_ceil_grid_idx = h_ceil_grid * num_grid_per_side + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - dw_grid + w11 + + idx00 = h_floor_grid_idx + w_floor_grid + idx01 = h_floor_grid_idx + w_ceil_grid + idx10 = h_ceil_grid_idx + w_floor_grid + idx11 = h_ceil_grid_idx + w_ceil_grid + + indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.dtype, device=self.device) + + embeds = self.pos_embed(indices) + weighted_embeds = embeds * weights + p0, p1, p2, p3 = weighted_embeds.unbind(dim=0) + combined = p0 + p1 + p2 + p3 + + combined = combined.view(h * w, hidden_dim) + repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous() + repeated = repeated.view( + t, h // m_size, m_size, w // m_size, m_size, hidden_dim + ) + repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) + + def compute_attn_mask_seqlen( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[int | None, list[int] | None]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = 
self.patch_embed(hidden_states) + + if self.apply_vit_abs_pos_embed: + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + hidden_states_list = [] + deepstack_visual_indexes = self.deepstack_visual_indexes + + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + if ( + deepstack_visual_indexes is not None + and layer_num in deepstack_visual_indexes + ): + hidden_states_list.append(hidden_states) + + hidden_states = self.merger(hidden_states) + + # processing deepstack + if deepstack_visual_indexes is not None: + processed_hidden_states_list = [hidden_states] + for idx, x in enumerate(hidden_states_list): + x = self.merger_list[idx](x) + processed_hidden_states_list.append(x) + # we cat the original visual features and deepstack features + # along the feature dim + hidden_states = torch.cat( + processed_hidden_states_list, dim=1 + ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + "deepstack_input_embeds": 0, + } +) +class Qwen3MoeLLMModel(Qwen3MoeModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + self.deepstack_multiscale_layer_start = 1 + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + deepstack_input_embeds: IntermediateTensors | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + 
for layer_idx, layer in enumerate( + self.layers[self.start_layer : self.end_layer] + ): + layer_idx = layer_idx + self.start_layer + + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and layer_idx in range( + 0, len(deepstack_input_embeds) + ): + hidden_states = ( + hidden_states + + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3MoeForCausalLM, self).__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeLLMModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + +class Qwen3OmniMoeThinkerProcessingInfo( + Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo +): + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3OmniMoeConfig).thinker_config + + def get_hf_processor(self, **kwargs: object) -> Qwen3OmniMoeProcessor: + processor = self.ctx.get_hf_processor( + Qwen3OmniMoeProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + if not hasattr(processor, "audio_token"): + processor.audio_token = "<|audio_pad|>" + if not hasattr(processor, "image_token"): + processor.image_token = "<|image_pad|>" + if not hasattr(processor, "video_token"): + processor.video_token = "<|video_pad|>" + return processor + + def get_feature_extractor(self, **kwargs: object): + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": None, "image": None, "video": None} + + +Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder + + +class Qwen3OmniMoeThinkerMultiModalProcessor( + Qwen2_5OmniThinkerMultiModalProcessor, +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: + length = x.shape[-1] + if length % hop_length != 0: + pad_length = hop_length - (length % hop_length) + x = np.pad(x, (0, pad_length), mode="constant", constant_values=0) + return x + + # NOTE: WhisperFeatureExtractor cannot handle empty list of audios + feature_extractor = self.info.get_feature_extractor() + hop_length = feature_extractor.hop_length + if audios: + # NOTE: Qwen3-Omni processor accept "audio" + # To make sure the cache works with padding=True, we pre-padded + # the audio to multiple of hop_length. 
+ mm_data["audio"] = [ + pad_to_hop_length(audio, hop_length) + if isinstance(audio, np.ndarray) + else (pad_to_hop_length(audio[0], hop_length), audio[1]) + for audio in audios + ] + + # TODO(Isotr0py): Remove this patch after upstream fix PR + # released and Transformers version update: + # https://github.com/huggingface/transformers/pull/41473 + mm_kwargs = dict(mm_kwargs) + tok_kwargs = dict(tok_kwargs) + if Version(TRANSFORMERS_VERSION) < Version("4.58.0"): + # move truncation to audio_kwargs level to avoid conflict + # with tok_kwargs + mm_kwargs["audio_kwargs"] = { + "truncation": mm_kwargs.pop("truncation", False) + } + mm_kwargs["text_kwargs"] = { + "truncation": tok_kwargs.pop("truncation", False) + } + + hf_inputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + if ( + "audio_feature_lengths" in hf_inputs + and "feature_attention_mask" in hf_inputs + and (audios := mm_data.get("audio", [])) + ): + audio_num_frames = [] + for _, audio in enumerate(audios): + audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio) + num_frame = ( + (audio_length // hop_length) + if audio_length % hop_length == 0 + else (audio_length // hop_length - 1) + ) + if mm_kwargs.get("truncation", False): + num_frame = min( + num_frame, feature_extractor.n_samples // hop_length + ) + audio_num_frames.append(num_frame) + hf_inputs["feature_attention_mask"] = [ + torch.ones(num_frame) for num_frame in audio_num_frames + ] + hf_inputs["audio_feature_lengths"] = torch.tensor(audio_num_frames) + return hf_inputs + + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + prompt_ids: list[int], + mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + """ + Qwen3-Omni reimplements this function to handle `use_audio_in_video`. 
+ """ + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + use_audio_in_video = False + if "video" in mm_kwargs: + for item in mm_kwargs["video"]: + if item and item["use_audio_in_video"].data: + use_audio_in_video = True + else: + use_audio_in_video = False + + if use_audio_in_video and "video" in mm_item_counts: + assert "audio" in mm_item_counts + mm_item_counts["audio"] -= mm_item_counts["video"] + + # Special case with `use_audio_in_video=True` + if use_audio_in_video: + if is_update_applied: + prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video) + ( + prompt_ids, + mm_placeholders, + ) = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + # normal case with `use_audio_in_video=False` + elif is_update_applied: + mm_placeholders = self._find_mm_placeholders( + prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + ) + else: + prompt_ids, mm_placeholders = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + ) + self._validate_mm_placeholders( + mm_placeholders, + mm_item_counts, + ) + + return prompt_ids, mm_placeholders + + def get_updates_use_audio_in_video( + self, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: list[int] | torch.Tensor, + video_second_per_grid_t: float, + ) -> list[int]: + shift = 0 + audio_token_id = thinker_config.audio_token_id + video_token_id = thinker_config.video_token_id + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + position_id_per_seconds = thinker_config.position_id_per_seconds + audio_token_indices = np.arange(next(iter([audio_len]))) + curr_video_grid_thw = next(iter([video_grid_thw])) + height = curr_video_grid_thw[1] // spatial_merge_size + width = curr_video_grid_thw[2] // spatial_merge_size + video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1) + video_token_indices = np.broadcast_to( + video_token_indices, (video_token_indices.shape[0], height, width) + ).reshape(-1) + video_token_indices = ( + (video_token_indices + shift) + * next(iter([video_second_per_grid_t])) + * position_id_per_seconds + ) + video_data_index, audio_data_index = 0, 0 + updates = [audio_start_token_id] + while video_data_index < len(video_token_indices) and audio_data_index < len( + audio_token_indices + ): + if ( + video_token_indices[video_data_index] + <= audio_token_indices[audio_data_index] + ): + updates += [video_token_id] + video_data_index += 1 + else: + updates += [audio_token_id] + audio_data_index += 1 + if video_data_index < len(video_token_indices): + updates += [video_token_id] * (len(video_token_indices) - video_data_index) + if audio_data_index < len(audio_token_indices): + updates += [audio_token_id] * (len(audio_token_indices) - audio_data_index) + updates += [audio_end_token_id] + return updates + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) + vocab = tokenizer.get_vocab() + + audio_token = processor.audio_token + image_token = 
processor.image_token + video_token = processor.video_token + audio_token_id = vocab[audio_token] + image_token_id = vocab[image_token] + video_token_id = vocab[video_token] + + out_mm_data = out_mm_kwargs.get_data() + audio_feature_lengths = out_mm_data.get("audio_feature_lengths") + feature_attention_mask = out_mm_data.get("feature_attention_mask") + if audio_feature_lengths is None and feature_attention_mask is None: + audio_output_lengths = [] + elif audio_feature_lengths is not None: + _, audio_output_lens = _get_feat_extract_output_lengths( + audio_feature_lengths + ) + audio_output_lengths = audio_output_lens.tolist() + elif feature_attention_mask is not None: + assert isinstance(feature_attention_mask, torch.Tensor) + _, audio_output_lens = _get_feat_extract_output_lengths( + feature_attention_mask.sum(-1) + ) + audio_output_lengths = audio_output_lens.tolist() + + # number of audios read from video. + audio_in_video_item_idx = 0 + audio_item_idx = 0 + + def get_replacement_qwen2_audio(item_idx: int): + nonlocal audio_item_idx + item_idx += audio_in_video_item_idx + + audio_item_idx += 1 + + num_features = audio_output_lengths[item_idx] + if num_features == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + raise ValueError( + f"The audio {audio} (len={len(audio)}) is too short " + "to be represented inside the model" + ) + + return [audio_token_id] * num_features + + def get_replacement_qwen2_vision(item_idx: int, modality: str): + grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + merge_length = image_processor.merge_size**2 + + token_id = image_token_id if modality == "image" else video_token_id + return [token_id] * (int(grid_thw.prod()) // merge_length) + + use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False) + thinker_config = self.info.get_hf_config() + + def get_replacement_qwen2_use_audio_in_video(item_idx: int): + nonlocal audio_in_video_item_idx + audio_num_features = audio_output_lengths[audio_item_idx + item_idx] + video_grid_thw = out_mm_data["video_grid_thw"][item_idx] + + audio_in_video_item_idx += 1 + + second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None) + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[item_idx] + else: + video_second_per_grid_t = 1.0 + + return self.get_updates_use_audio_in_video( + thinker_config=thinker_config, + audio_len=audio_num_features, + video_grid_thw=video_grid_thw, + video_second_per_grid_t=video_second_per_grid_t, + ) + + video_replacement_fn = ( + get_replacement_qwen2_use_audio_in_video + if use_audio_in_video + else partial(get_replacement_qwen2_vision, modality="video") + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ), + PromptReplacement( + modality="image", + target=image_token, + replacement=partial(get_replacement_qwen2_vision, modality="image"), + ), + PromptReplacement( + modality="video", + target=video_token, + replacement=video_replacement_fn, + ), + ] + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_item_counts: Mapping[str, int], + ) -> None: + BaseMultiModalProcessor[ + Qwen2_5OmniThinkerProcessingInfo + ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts) + + def _get_raw_input_ids( + self, + token_ids: list[int], + use_audio_in_video: bool = False, + ) -> list[int]: + tokenizer = 
self.info.get_tokenizer() + vision_bos_token = tokenizer.encode(tokenizer.vision_bos_token)[0] + vision_eos_token = tokenizer.encode(tokenizer.vision_eos_token)[0] + audio_bos_token = tokenizer.encode(tokenizer.audio_bos_token)[0] + audio_eos_token = tokenizer.encode(tokenizer.audio_eos_token)[0] + audio_token = tokenizer.encode("<|audio_pad|>")[0] + image_token = tokenizer.encode("<|image_pad|>")[0] + video_token = tokenizer.encode("<|video_pad|>")[0] + + result = token_ids[:] + if use_audio_in_video: + while True: + start = None + for i in range(len(result) - 1): + if result[i : i + 2] == [vision_bos_token, audio_bos_token]: + start = i + break + if start is not None: + end = None + for i in range(start + 2, len(result) - 1): + if result[i : i + 2] == [audio_eos_token, vision_eos_token]: + end = i + break + if end is not None: + result = ( + result[:start] + + [vision_bos_token, video_token, vision_eos_token] + + result[end + 2 :] + ) + else: + break + + for mm_token in [audio_token, image_token, video_token]: + compressed = [] + for x in result: + if x != mm_token or (not compressed or compressed[-1] != mm_token): + compressed.append(x) + result = compressed + + return result + + +class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin): + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str, dim: int = 0 + ) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") + if name == "feature_attention_mask": + dim = -1 + if isinstance(mm_input, torch.Tensor): + return torch.concat(list(mm_input), dim=dim) + else: + if isinstance(mm_input[0], list): + return torch.concat( + [torch.concat(mm_input[i], dim=dim) for i in range(len(mm_input))], + dim=dim, + ) + else: + return torch.concat(mm_input, dim=dim) + + def _process_audio_input( + self, + audio_input: Qwen2AudioFeatureInputs, + audio_hashes: list[str] = None, + cached_audio_features: torch.Tensor = None, + ) -> torch.Tensor: + input_features = audio_input["input_features"] + audio_feature_lengths = audio_input["audio_feature_lengths"] + + if input_features.ndim == 3: + assert input_features.shape[0] == 1 + input_features = input_features.squeeze(0) + + if not isinstance(audio_feature_lengths, torch.Tensor): + audio_feature_lengths = torch.cat(audio_feature_lengths) + if audio_feature_lengths.ndim == 2: + audio_feature_lengths = audio_feature_lengths.reshape(-1) + + audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths( + audio_feature_lengths + ) + + audio_outputs = self.audio_tower( + input_features.to(self.audio_tower.dtype), + feature_lens=audio_feature_lengths, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + return audio_features.split(audio_output_lengths.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3OmniMoeThinkerMultiModalProcessor, + info=Qwen3OmniMoeThinkerProcessingInfo, + dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder, +) +class Qwen3OmniMoeThinkerForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsMRoPE, + Qwen3OmniMoeConditionalGenerationMixin, +): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return 
"<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + if modality.startswith("audio"): + return "<|audio_start|><|audio_pad|><|audio_end|>" + + raise ValueError("Only image, video or audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + thinker_config: Qwen3OmniMoeThinkerConfig = ( + vllm_config.model_config.hf_config.thinker_config + ) + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = thinker_config + self.multimodal_config = multimodal_config + + # force "use_flash_attention_2=True" to audio tower to align + # the results. + if flash_attn is not None: + audio_config = thinker_config.audio_config + audio_config._attn_implementation_autoset = True + audio_config._attn_implementation = "flash_attention_2" + else: + logger.warning( + "flash_attn is not available, the model may not yield the " + "exactly same result as the transformers implementation " + "in the audio tower part." + ) + + self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config) + + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) + self.visual = Qwen3Omni_VisionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + attn_backend_override=attn_backend_override, + ) + self.quant_config = quant_config + + self.language_model = Qwen3MoeLLMForCausalLM( + vllm_config=vllm_config.with_hf_config( + thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"] + ), + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + self.use_deepstack = hasattr( + thinker_config.vision_config, "deepstack_visual_indexes" + ) + self.deepstack_num_level = ( + len(thinker_config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) + # register buffer for deepstack + self.deepstack_input_embeds = ( + [ + torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + thinker_config.text_config.hidden_size, + ) + for _ in range(self.deepstack_num_level) + ] + if self.use_deepstack + else None + ) + self.visual_dim = thinker_config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors: + # get deepstack_input_embeds from buffer, and clear the buffer + return IntermediateTensors( + { + f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][ + :num_tokens + ] + for idx in range(self.deepstack_num_level) + } + ) + + def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: + # set deepstack_input_embeds to buffer + num_tokens = deepstack_input_embeds.size(1) + if num_tokens > self.deepstack_input_embeds[0].size(0): + self.deepstack_input_embeds = [ + torch.zeros( + num_tokens, + self.config.text_config.hidden_size, + device=self.deepstack_input_embeds[0].device, + dtype=self.deepstack_input_embeds[0].dtype, + ) + for _ in range(self.deepstack_num_level) + ] + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].copy_( + deepstack_input_embeds[idx] + ) + + def 
_clear_deepstack_input_embeds(self, num_tokens: int) -> None: + # clear deepstack_input_embeds in buffer + if num_tokens > 0: + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].zero_() + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) + if ( + input_key in ("input_audio_features") + and "audio" not in mm_input_by_modality + ): + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input( + **kwargs + ) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object + ) -> MultiModalEmbeddings | None: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) + if not mm_input_by_modality: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += tuple(video_embeddings) + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += tuple(audio_embeddings) + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self._get_text_embeddings( + input_ids, + self.language_model.get_input_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + deepstack_input_embeds = None + # TODO (ywang96): support overlapping modalitiy embeddings so that + # `use_audio_in_video` will work on V1. 
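The block that follows in `get_input_embeddings` splits each vision embedding into its main feature plus the concatenated deepstack (multi-scale) features, then reshapes them into one slice per level. A shape-only sketch of that split with made-up sizes (nothing below is code from this diff):

```python
import torch

visual_dim, multiscale_len, num_tokens = 16, 3, 5  # made-up sizes
# Vision embeddings arrive as [main | level_0 | level_1 | ...] along the last dim.
embeddings = torch.randn(num_tokens, visual_dim * (multiscale_len + 1))

main, multiscale = torch.split(
    embeddings, [visual_dim, visual_dim * multiscale_len], dim=-1
)
# (num_tokens, levels, dim) -> (levels, num_tokens, dim): one buffer per deepstack level.
per_level = multiscale.view(num_tokens, multiscale_len, visual_dim).permute(1, 0, 2)
assert main.shape == (num_tokens, visual_dim)
assert per_level.shape == (multiscale_len, num_tokens, visual_dim)
```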
+ # split the feat dim to obtain multi-scale visual feature + has_vision_embeddings = [ + embeddings.shape[-1] != self.config.text_config.hidden_size + for embeddings in multimodal_embeddings + ] + if self.visual.deepstack_visual_indexes is not None and any( + has_vision_embeddings + ): + multiscale_len = len(self.visual.deepstack_visual_indexes) + multimodal_embeddings_multiscale = [] + is_vision = torch.zeros_like(is_multimodal) + mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0] + mm_position_idx = 0 + for index, embeddings in enumerate(multimodal_embeddings): + num_tokens = embeddings.shape[0] + current_positions = mm_positions[ + mm_position_idx : mm_position_idx + num_tokens + ] + + # Vision embeddings + if embeddings.shape[-1] != self.config.text_config.hidden_size: + visual_dim = embeddings.shape[-1] // (multiscale_len + 1) + multi_dim = visual_dim * multiscale_len + embeddings_main, embeddings_multiscale = torch.split( + embeddings, [visual_dim, multi_dim], dim=-1 + ) + multimodal_embeddings[index] = embeddings_main + multimodal_embeddings_multiscale.append(embeddings_multiscale) + is_vision[current_positions] = True + + # Audio embeddings + else: + is_vision[current_positions] = False + + mm_position_idx += num_tokens + + deepstack_input_embeds = inputs_embeds.new_zeros( + inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1) + ) + deepstack_input_embeds = _merge_multimodal_embeddings( + inputs_embeds=deepstack_input_embeds, + multimodal_embeddings=multimodal_embeddings_multiscale, + is_multimodal=is_vision, + ) + deepstack_input_embeds = ( + deepstack_input_embeds.view( + inputs_embeds.shape[0], multiscale_len, visual_dim + ) + .permute(1, 0, 2) + .contiguous() + ) + self._set_deepstack_input_embeds(deepstack_input_embeds) + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + if ( + self.use_deepstack + and inputs_embeds is not None + and get_pp_group().is_first_rank + ): + deepstack_input_embeds = self._get_deepstack_input_embeds( + inputs_embeds.size(0) + ) + else: + deepstack_input_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + # args for deepstack + deepstack_input_embeds=deepstack_input_embeds, + ) + + if inputs_embeds is not None and get_pp_group().is_first_rank: + self._clear_deepstack_input_embeds(inputs_embeds.size(0)) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["talker.", "code2wav."], + ) + loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + return loaded_weights + + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None = None, 
+ context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + config = hf_config.thinker_config + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + input_ids = torch.tensor(input_tokens) + if input_ids is None or input_ids.ndim != 1: + raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids") + + seq_len = input_ids.shape[0] + if audio_feature_lengths is not None and not isinstance( + audio_feature_lengths, torch.Tensor + ): + audio_feature_lengths = torch.as_tensor( + audio_feature_lengths, dtype=torch.long + ) + if second_per_grid_ts is None: + if video_grid_thw is not None and video_grid_thw.numel() > 0: + second_per_grids = torch.ones( + video_grid_thw.shape[0], dtype=torch.float32 + ) + else: + second_per_grids = torch.tensor([], dtype=torch.float32) + else: + second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32) + + spatial_merge_size = config.vision_config.spatial_merge_size + image_token_id = config.image_token_id + video_token_id = config.video_token_id + audio_token_id = config.audio_token_id + vision_start_token_id = config.vision_start_token_id + audio_start_token_id = config.audio_start_token_id + position_id_per_seconds = config.position_id_per_seconds + + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + if vision_start_indices.numel() > 0: + vision_tokens = input_ids[vision_start_indices + 1] + else: + vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype) + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + + llm_pos_ids_list: list[torch.Tensor] = [] + st = 0 + image_idx = 0 + video_idx = 0 + audio_idx = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501 + multimodal_nums = ( + image_nums + audio_nums + if use_audio_in_video + else image_nums + video_nums + audio_nums + ) # noqa: E501 + + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + if (image_token_id in input_tokens or video_token_id in input_tokens) and ( + remain_videos > 0 or remain_images > 0 + ): + ed_vision_start = input_tokens.index(vision_start_token_id, st) + else: + ed_vision_start = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio_start = input_tokens.index(audio_start_token_id, st) + else: + ed_audio_start = len(input_tokens) + 1 + min_ed = min(ed_vision_start, ed_audio_start) + + if min_ed == ed_audio_start: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + _, audio_len = _get_feat_extract_output_lengths( + audio_feature_lengths[audio_idx] + ) + llm_pos_ids = ( + torch.arange(audio_len, 
dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + elif ( + min_ed == ed_vision_start + and input_ids[ed_vision_start + 1] == image_token_id + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = torch.arange(grid_t) * position_id_per_seconds + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + elif ( + min_ed == ed_vision_start + and input_ids[ed_vision_start + 1] == video_token_id + and not use_audio_in_video + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) + * float(second_per_grids[video_idx].item()) + * position_id_per_seconds + ) + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + elif ( + min_ed == ed_vision_start + and ed_vision_start + 1 == ed_audio_start + and use_audio_in_video + ): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + bos_len = 1 + bos_block = ( + torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + llm_pos_ids_list.append(bos_block) + 
llm_pos_ids_list.append(bos_block) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + _, audio_len = _get_feat_extract_output_lengths( + audio_feature_lengths[audio_idx] + ) + audio_llm_pos_ids = ( + torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) + * float(second_per_grids[video_idx].item()) + * position_id_per_seconds + ) + video_llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_data_index, audio_data_index = 0, 0 + while ( + video_data_index < video_llm_pos_ids.shape[-1] + and audio_data_index < audio_llm_pos_ids.shape[-1] + ): + if ( + video_llm_pos_ids[0][video_data_index] + <= audio_llm_pos_ids[0][audio_data_index] + ): + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_data_index : video_data_index + 1 + ] + ) + video_data_index += 1 + else: + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_data_index : audio_data_index + 1 + ] + ) + audio_data_index += 1 + if video_data_index < video_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_data_index : video_llm_pos_ids.shape[-1] + ] + ) + if audio_data_index < audio_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_data_index : audio_llm_pos_ids.shape[-1] + ] + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + eos_len = 1 + eos_block = ( + torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + llm_pos_ids_list.append(eos_block) + llm_pos_ids_list.append(eos_block) + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501 + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1) + + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + if llm_positions.shape[1] != seq_len: + raise RuntimeError("Position ids length mismatch with input ids length") + + mrope_position_delta = llm_positions.max() + 1 - seq_len + return llm_positions, mrope_position_delta diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1c532376256d..940fa50ff803 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -24,15 +24,17 @@ # limitations under the License. 
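For intuition, a standalone replica of `_get_feat_extract_output_lengths` from the top of qwen3_omni_moe_thinker.py above, which the processor and the mRoPE code use to turn raw mel-frame counts into the number of audio placeholder tokens. The example lengths are arbitrary; only the arithmetic is taken from the diff.

```python
import torch


def feat_extract_output_lengths(input_lengths: torch.Tensor):
    # Same arithmetic as _get_feat_extract_output_lengths in the diff above:
    # the remainder modulo 100 goes through two stride-2 stages, and every
    # full block of 100 frames contributes 13 output embeddings.
    leave = input_lengths % 100
    feat_lengths = (leave - 1) // 2 + 1
    output_lengths = (
        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    )
    return feat_lengths, output_lengths


_, out = feat_extract_output_lengths(torch.tensor([250, 80, 320]))
print(out.tolist())  # [33, 10, 42]
```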
"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" -from collections.abc import Iterable, Mapping, Sequence +import math +from collections.abc import Callable, Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Optional, Union +from itertools import islice +from typing import Any import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BatchFeature +from transformers import BatchFeature, PretrainedConfig from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( smart_resize as image_smart_resize, @@ -55,7 +57,11 @@ from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY -from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -78,11 +84,12 @@ ) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of +from vllm.utils.collection_utils import is_list_of from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, + SupportsMRoPE, SupportsMultiModal, SupportsPP, ) @@ -105,7 +112,11 @@ _merge_multimodal_embeddings, maybe_prefix, ) -from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model +from .vision import ( + conv3d_to_linear_weight, + get_vit_attn_backend, + run_dp_sharded_mrope_vision_model, +) logger = init_logger(__name__) @@ -127,18 +138,15 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d( - in_channels, + self.proj = ReplicatedLinear( + in_channels * math.prod(kernel_size), hidden_size, - kernel_size=kernel_size, - stride=kernel_size, bias=True, + return_bias=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x).view(L, self.hidden_size) + x = self.proj(x) return x @@ -149,7 +157,7 @@ def __init__( hidden_features: int, bias: bool = False, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -186,8 +194,8 @@ def __init__( num_heads: int, mlp_hidden_dim: int, act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, + norm_layer: Callable[[int], nn.Module] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, attn_backend: _Backend = _Backend.TORCH_SDPA, @@ -223,8 +231,8 @@ def forward( x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for 
xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), @@ -243,10 +251,10 @@ def __init__( self, d_model: int, context_dim: int, - norm_layer: Optional[Callable[[int], nn.Module]] = None, + norm_layer: Callable[[int], nn.Module] | None = None, spatial_merge_size: int = 2, use_postshuffle_norm: bool = False, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ) -> None: @@ -295,9 +303,10 @@ def __init__( self, vision_config: Qwen3VLVisionConfig, norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -357,7 +366,9 @@ def __init__( ) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype() + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) use_upstream_fa = False if ( @@ -377,7 +388,6 @@ def __init__( raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now." ) - self.blocks = nn.ModuleList( [ Qwen3_VisionBlock( @@ -465,8 +475,6 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij") h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") - h_floor_grid_idx = h_floor_grid * num_grid_per_side - h_ceil_grid_idx = h_ceil_grid * num_grid_per_side # original computation of weights # w00 = (1 - dh_grid) * (1 - dw_grid) @@ -478,28 +486,25 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: w11 = dh_grid * dw_grid w10 = dh_grid - w11 w01 = dw_grid - w11 - w00 = 1 - dh_grid - dw_grid + w11 + w00 = 1 - dh_grid - w01 - idx00 = h_floor_grid_idx + w_floor_grid - idx01 = h_floor_grid_idx + w_ceil_grid - idx10 = h_ceil_grid_idx + w_floor_grid - idx11 = h_ceil_grid_idx + w_ceil_grid + h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) + w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) + h_grid_idx = h_grid * num_grid_per_side - indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1) + indices = (h_grid_idx + w_grid).reshape(4, -1) weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) - weights = weights.to(dtype=self.dtype, device=self.device) + weights = weights.to(dtype=self.dtype) embeds = self.pos_embed(indices) weighted_embeds = embeds * weights - p0, p1, p2, p3 = weighted_embeds.unbind(dim=0) - combined = p0 + p1 + p2 + p3 + combined = weighted_embeds.sum(dim=0) - combined = combined.view(h * w, hidden_dim) - repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous() - repeated = repeated.view( - t, h // m_size, m_size, w // m_size, m_size, hidden_dim + combined = combined.reshape( + h // m_size, m_size, w // m_size, m_size, hidden_dim ) - repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim) + combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) + repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) outputs.append(repeated) return torch.cat(outputs, dim=0) @@ -507,7 +512,7 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: def compute_attn_mask_seqlen( self, 
cu_seqlens: torch.Tensor, - ) -> tuple[Optional[int], Optional[list[int]]]: + ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( self.attn_backend == _Backend.FLASH_ATTN @@ -523,14 +528,15 @@ def forward( x: torch.Tensor, grid_thw: list[list[int]], ) -> torch.Tensor: - hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True) hidden_states = self.patch_embed(hidden_states) pos_embeds = self.fast_pos_embed_interpolate(grid_thw) hidden_states = hidden_states + pos_embeds rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True) - grid_thw_tensor = torch.tensor(grid_thw, device=self.device, dtype=torch.int32) + grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32) cu_seqlens = torch.repeat_interleave( grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0] @@ -538,11 +544,11 @@ def forward( dim=0, dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32, ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) hidden_states = hidden_states.unsqueeze(1) - rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): @@ -576,6 +582,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: + if name.endswith("patch_embed.proj.weight"): + loaded_weight = conv3d_to_linear_weight(loaded_weight) + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -620,9 +629,7 @@ def _get_vision_info( image_height: int, num_frames: int = 2, do_resize: bool = True, - image_processor: Optional[ - Union[Qwen2VLImageProcessorFast, Qwen3VLVideoProcessor] - ], + image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None, ) -> tuple[ImageSize, int]: if image_processor is None and num_frames > 1: image_processor = self.get_video_processor() @@ -721,8 +728,8 @@ def _get_video_second_idx( self, metadata: dict[str, Any], out_item: MultiModalKwargsItem, - do_sample_frames: Optional[bool] = None, - sampled_fps: Optional[float] = None, + do_sample_frames: bool | None = None, + sampled_fps: float | None = None, ) -> list[int]: video_processor = self.get_video_processor() merge_size = video_processor.merge_size @@ -739,9 +746,9 @@ def _get_video_second_idx( if do_sample_frames: # here video_fps is the fps of the sampled video, and # metadata["fps"] refers to the fps of the original video. 
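# Illustrative note (not part of the patch): for a video with
# metadata["total_num_frames"] = 300, metadata["fps"] = 30.0 and
# sampled_fps = 2.0, the computation below gives
# num_frames = int(300 / 30.0 * 2.0) = 20, before the subsequent
# min/max clamping against the processor's frame limits.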
- video_fps = sampled_fps if sampled_fps else video_processor.fps + sampled_fps = sampled_fps if sampled_fps else video_processor.fps total_num_frames = metadata["total_num_frames"] - num_frames = int(total_num_frames / metadata["fps"] * video_fps) + num_frames = int(total_num_frames / metadata["fps"] * sampled_fps) num_frames = min( min( max(num_frames, video_processor.min_frames), @@ -773,7 +780,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -891,16 +898,12 @@ def _call_hf_processor( processor = self.info.get_hf_processor(**mm_kwargs) # Separate video processing from image processing. Because the videos - # are processed into serval image patches - if ( - "videos" in mm_data - and isinstance(mm_data["videos"], list) - and len(mm_data["videos"]) > 0 - ): + # are processed into several image patches + if videos := mm_data.pop("videos", []): video_grid_thw_lst = [] pixel_values_videos_lst = [] - for item_idx, item in enumerate(mm_data.pop("videos", [])): + for item in videos: video_array, metadata = item # NOTE: @JJJYmmm new attr metadata.frames_indices indicates @@ -1091,11 +1094,11 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, # args for deepstack - deepstack_input_embeds: Optional[IntermediateTensors] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + deepstack_input_embeds: IntermediateTensors | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -1106,11 +1109,9 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer_idx, layer in enumerate( - self.layers[self.start_layer : self.end_layer] + for layer_idx, layer in islice( + enumerate(self.layers), self.start_layer, self.end_layer ): - layer_idx = layer_idx + self.start_layer - hidden_states, residual = layer( positions, hidden_states, @@ -1172,7 +1173,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): dummy_inputs=Qwen3VLDummyInputsBuilder, ) class Qwen3VLForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE ): packed_modules_mapping = { "qkv_proj": [ @@ -1198,7 +1199,7 @@ class Qwen3VLForConditionalGeneration( ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<|vision_start|><|image_pad|><|vision_end|>" if modality.startswith("video"): @@ -1220,12 +1221,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): ) and not multimodal_config.get_limit_per_prompt("video"): self.visual = None else: + attn_backend_override = ( + multimodal_config.mm_encoder_attn_backend + if multimodal_config is not None + else None + ) self.visual = Qwen3_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), quant_config=quant_config, 
prefix=maybe_prefix(prefix, "visual"), use_data_parallel=self.use_data_parallel, + attn_backend_override=attn_backend_override, ) self.language_model = Qwen3LLMForCausalLM( @@ -1305,13 +1312,13 @@ def _validate_and_reshape_mm_tensor( f"Got ndim: {mm_input.ndim} " f"(shape={mm_input.shape})" ) - return torch.concat(list(mm_input)) + return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Qwen2_5_VLImageInputs]: + ) -> Qwen2_5_VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1360,7 +1367,7 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object - ) -> Optional[Qwen2_5_VLVideoInputs]: + ) -> Qwen2_5_VLVideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1478,12 +1485,121 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: ) return mm_input_by_modality + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: list[list[int]] | torch.Tensor, + video_grid_thw: list[list[int]] | torch.Tensor, + context_len: int = 0, + seq_len: int | None = None, + second_per_grid_ts: list[float] | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id + ).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * 
llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions, mrope_position_delta + def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( self, **kwargs: object - ) -> Optional[MultiModalEmbeddings]: + ) -> MultiModalEmbeddings | None: mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None @@ -1497,11 +1613,11 @@ def get_multimodal_embeddings( for modality in mm_input_by_modality: multimodal_input = mm_input_by_modality[modality] if modality == "image": - vision_embeddings = self._process_image_input(multimodal_input) - multimodal_embeddings += vision_embeddings + image_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) - multimodal_embeddings += video_embeddings + multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings def _compute_deepstack_embeds( @@ -1548,9 +1664,9 @@ def _compute_deepstack_embeds( def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: inputs_embeds = self._get_text_embeddings( @@ -1589,12 +1705,6 @@ def get_input_embeddings( ) if deepstack_input_embeds is not None: - deepstack_input_embeds = ( - torch.zeros_like(inputs_embeds) - .unsqueeze(0) - .repeat(self.deepstack_num_level, 1, 1) - .contiguous() - ) self._set_deepstack_input_embeds(deepstack_input_embeds) return inputs_embeds @@ -1603,10 +1713,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Qwen3VL. 
Args: @@ -1662,7 +1772,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1678,6 +1788,6 @@ def get_mm_mapping(self) -> MultiModelKeys: """ return MultiModelKeys.from_string_field( language_model="language_model", - connector="model.visual.merger", - tower_model="model.visual.", + connector="visual.merger", + tower_model="visual.", ) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index cd8046d04248..284b1301d07f 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -25,8 +25,8 @@ """Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights.""" import typing -from collections.abc import Iterable -from typing import Callable, Optional, Union +from collections.abc import Callable, Iterable +from itertools import islice import torch from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig @@ -89,10 +89,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - deepstack_input_embeds: Optional[IntermediateTensors] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + deepstack_input_embeds: IntermediateTensors | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -103,11 +103,9 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer_idx, layer in enumerate( - self.layers[self.start_layer : self.end_layer] + for layer_idx, layer in islice( + enumerate(self.layers), self.start_layer, self.end_layer ): - layer_idx = layer_idx + self.start_layer - hidden_states, residual = layer( positions, hidden_states, @@ -352,6 +350,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): dummy_inputs=Qwen3VLDummyInputsBuilder, ) class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super(Qwen3VLForConditionalGeneration, self).__init__() config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config @@ -378,6 +384,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.language_model = Qwen3MoeLLMForCausalLM( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") ) + # Whether to include the gate_up_proj mapping is determined by + # the language model. 
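# Illustrative note (not part of the patch): the `|` in the statement below is
# a plain dict union, e.g. {"qkv_proj": [...]} | {"gate_up_proj": [...]} yields
# a mapping containing both entries; on duplicate keys the language model's
# entry takes precedence.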
+ self.packed_modules_mapping = ( + self.packed_modules_mapping | self.language_model.packed_modules_mapping + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 1786ea6a6878..f011229985c8 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -9,9 +9,9 @@ import copy import math import unicodedata -from collections.abc import Collection, Mapping, Sequence, Set +from collections.abc import Callable, Collection, Mapping, Sequence, Set from functools import lru_cache, partial -from typing import Annotated, Callable, Literal, Optional, Union +from typing import Annotated, Literal, TypeAlias import regex as re import torch @@ -93,7 +93,7 @@ class QwenImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")] -QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] +QwenImageInputs: TypeAlias = QwenImagePixelInputs | QwenImageEmbeddingInputs class VisualAttention(nn.Module): @@ -107,8 +107,8 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, - kdim: Optional[int] = None, - vdim: Optional[int] = None, + kdim: int | None = None, + vdim: int | None = None, ): super().__init__() self.embed_dim = embed_dim @@ -135,7 +135,7 @@ def __init__( def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: # query/key/value: [sq, b, h] sq, b, _ = x.size() @@ -213,7 +213,7 @@ def __init__( self, hidden_size: int, intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.c_fc = ColumnParallelLinear( @@ -241,7 +241,7 @@ def __init__( n_head: int, mlp_ratio: float = 4.0, norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() @@ -258,7 +258,7 @@ def __init__( def attention( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None return self.attn(x, attn_mask=attn_mask) @@ -266,7 +266,7 @@ def attention( def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) x = x + self.mlp(self.ln_2(x)) @@ -281,7 +281,7 @@ def __init__( heads: int, mlp_ratio: float = 4.0, norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.width = width @@ -307,7 +307,7 @@ def get_cast_device(self) -> torch.device: return self.resblocks[0].mlp.c_fc.weight.device def forward( - self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None + self, x: torch.Tensor, attn_mask: torch.Tensor | None = None ) -> torch.Tensor: for r in self.resblocks: x = r(x, attn_mask=attn_mask) @@ -326,7 +326,7 @@ def __init__( n_queries: int = 256, output_dim: int = 512, image_start_id: int = 151857, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, **kwargs, ): super().__init__() @@ -434,10 +434,10 @@ class TokenizerWithoutImagePad(tokenizer.__class__): # 
type: ignore def tokenize( self, text: str, - allowed_special: Union[Set[str], str] = "all", - disallowed_special: Union[Collection[str], str] = (), + allowed_special: Set[str] | str = "all", + disallowed_special: Collection[str] | str = (), **kwargs, - ) -> list[Union[bytes, str]]: + ) -> list[bytes | str]: text = unicodedata.normalize("NFC", text) return [ @@ -451,9 +451,9 @@ def tokenize( def _decode( self, - token_ids: Union[int, list[int]], + token_ids: int | list[int], skip_special_tokens: bool = False, - errors: Optional[str] = None, + errors: str | None = None, **kwargs, ) -> str: if isinstance(token_ids, int): @@ -523,9 +523,9 @@ def image_pad_tag(self) -> str: def __call__( self, - text: Optional[Union[TextInput, list[TextInput]]] = None, - images: Optional[Union[ImageInput, list[ImageInput]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: TextInput | list[TextInput] | None = None, + images: ImageInput | list[ImageInput] | None = None, + return_tensors: str | TensorType | None = None, ) -> BatchFeature: if text is None: text = [] @@ -568,7 +568,7 @@ def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor: **kwargs, ) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens(self) -> int: @@ -597,7 +597,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: hf_config = self.info.get_hf_config() vision_config = hf_config.visual @@ -722,7 +722,7 @@ def get_mm_mapping(self) -> MultiModelKeys: ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return f"Picture {i}: <img></img>" @@ -745,7 +745,7 @@ def __init__( def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[QwenImageInputs]: + ) -> QwenImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -799,10 +799,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None diff --git a/vllm/model_executor/models/radio.py b/vllm/model_executor/models/radio.py index 2313b98348b7..6a42564ac70a 100644 --- a/vllm/model_executor/models/radio.py +++ b/vllm/model_executor/models/radio.py @@ -11,7 +11,7 @@ import math from collections.abc import Iterable from itertools import repeat -from typing import Optional, Union +from typing import TypeAlias import torch import torch.nn as nn @@ -23,8 +23,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.intern_vit import InternVisionEncoder -input_dim_t = Union[int, tuple[int, int]] -norm_t = Union[tuple[float, float, float], torch.Tensor] +input_dim_t: TypeAlias = int | tuple[int, int] +norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor def _ntuple(n): @@ -43,40 +43,14 @@ def parse(x): to_ntuple = 
_ntuple -class InputConditioner(nn.Module): - def __init__( - self, - input_scale: float, - norm_mean: norm_t, - norm_std: norm_t, - dtype: torch.dtype = None, - ): - super().__init__() - - self.dtype = dtype - - self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale) - self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale) - - def forward(self, x: torch.Tensor): - y = (x - self.norm_mean) / self.norm_std - if self.dtype is not None: - y = y.to(self.dtype) - return y - - -def _to_tensor(v: norm_t): - return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1) - - class ClsToken(nn.Module): def __init__( self, ndim: int, num_tokens: int = 1, enabled: bool = True, - register_multiple: Optional[int] = None, - num_registers: Optional[int] = None, + register_multiple: int | None = None, + num_registers: int | None = None, ): super().__init__() @@ -128,12 +102,12 @@ def __init__( abs_pos: bool = True, normalize_patches: bool = False, cls_token: bool = False, - max_input_dims: Optional[input_dim_t] = None, + max_input_dims: input_dim_t | None = None, pos_dropout: float = 0.0, return_pos_enc: bool = False, num_cls_tokens: int = 1, - register_multiple: Optional[int] = None, - num_registers: Optional[int] = None, + register_multiple: int | None = None, + num_registers: int | None = None, patch_bias: bool = False, device=None, dtype=None, @@ -275,8 +249,8 @@ def embed_patches(self, x: torch.Tensor) -> torch.Tensor: def apply_pos_enc( self, patches: torch.Tensor, - patch_idxs: Optional[torch.Tensor] = None, - input_size: Optional[tuple[int, int]] = None, + patch_idxs: torch.Tensor | None = None, + input_size: tuple[int, int] | None = None, ) -> torch.Tensor: if not self.abs_pos: return patches @@ -299,8 +273,8 @@ def apply_pos_enc( def get_pos_enc( self, batch_size: int, - patch_idxs: Optional[torch.Tensor] = None, - input_size: Optional[tuple[int, int]] = None, + patch_idxs: torch.Tensor | None = None, + input_size: tuple[int, int] | None = None, ) -> torch.Tensor: if input_size is None: input_dims = self.input_dims @@ -440,9 +414,9 @@ class RadioInternVisionModel(nn.Module): def __init__( self, config: PretrainedConfig = None, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, num_dummy_heads: int = 0, prefix: str = "", ) -> None: @@ -472,7 +446,7 @@ def __init__( prefix=f"{prefix}.encoder", ) - def _init_img_size(self, patch_size, img_size: Union[int, tuple[int, int]]): + def _init_img_size(self, patch_size, img_size: int | tuple[int, int]): if img_size is None: return None, None, None img_size = to_2tuple(img_size) @@ -498,20 +472,15 @@ class RadioModel(nn.Module): def __init__( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, + num_hidden_layers_override: int | None = None, num_dummy_heads: int = 0, prefix: str = "", ) -> None: super().__init__() self.config = config - self.input_conditioner = InputConditioner( - input_scale=1.0, - norm_mean=config.norm_mean, - norm_std=config.norm_std, - ) self.model = RadioInternVisionModel( config=config, quant_config=quant_config, @@ -522,11 +491,10 @@ def __init__( def forward( self, - pixel_values: Optional[torch.Tensor] = None, - pixel_embeds: Optional[torch.Tensor] = None, + pixel_values: torch.Tensor | None = None, + pixel_embeds: 
torch.Tensor | None = None, ) -> torch.FloatTensor: - x = self.input_conditioner(pixel_values) - y = self.model(x) + y = self.model(pixel_values) return self._extract_final(y) def load_weights(self, weights) -> set[str]: @@ -548,6 +516,10 @@ def load_weights(self, weights) -> set[str]: # Skip buffers not used in vLLM if sub in {"summary_idxs"}: continue + if sub.startswith("input_conditioner."): + # we normalize in the input processor, + # based on norm and std values from the config + continue vllm_key = None if sub.startswith("model.patch_generator."): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 32e50f9a8e48..81d4a6bc5f3a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -14,11 +14,11 @@ import sys import tempfile from abc import ABC, abstractmethod -from collections.abc import Set +from collections.abc import Callable, Set from dataclasses import asdict, dataclass, field from functools import lru_cache from pathlib import Path -from typing import Callable, Optional, TypeVar, Union +from typing import TypeVar import torch.nn as nn import transformers @@ -44,7 +44,6 @@ supports_multimodal_raw_input_only, supports_pp, supports_transcription, - supports_v0_only, ) from .interfaces_base import ( get_default_pooling_type, @@ -61,9 +60,6 @@ "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), - "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), - "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), - "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), # baichuan-7b, upper case 'C' in the class name "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name @@ -88,8 +84,11 @@ "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"), - "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), + "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"), + "FlexOlmoForCausalLM": ("flex_olmo", "FlexOlmoForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), @@ -126,11 +125,12 @@ "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), - "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), - "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), + "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), + "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), + "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), # transformers's mpt class has lower case @@ -171,6 +171,7 @@ 
_EMBEDDING_MODELS = { # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), + "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma3TextModel": ("gemma3", "Gemma3Model"), @@ -208,6 +209,7 @@ ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 + "SiglipModel": ("siglip", "SiglipEmbeddingModel"), # Technically Terratorch models work on images, both in # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. @@ -246,6 +248,7 @@ "aya_vision", "AyaVisionForConditionalGeneration", ), + "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ( "chameleon", @@ -256,6 +259,7 @@ "Cohere2VisionForConditionalGeneration", ), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "DeepseekOCRForCausalLM": ("deepseek_ocr", "DeepseekOCRForCausalLM"), "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"), "Ernie4_5_VLMoeForConditionalGeneration": ( "ernie45_vl", @@ -297,6 +301,10 @@ ), "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 + "LightOnOCRForConditionalGeneration": ( + "lightonocr", + "LightOnOCRForConditionalGeneration", + ), "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), @@ -354,6 +362,10 @@ "qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration", ), + "Qwen3OmniMoeForConditionalGeneration": ( + "qwen3_omni_moe_thinker", + "Qwen3OmniMoeThinkerForConditionalGeneration", + ), "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501 "Qwen3VLMoeForConditionalGeneration": ( "qwen3_vl_moe", @@ -396,32 +408,44 @@ # Text generation models "SmolLM3ForCausalLM": ("transformers", "TransformersForCausalLM"), # Multimodal models - "Emu3ForConditionalGeneration": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 + "Emu3ForConditionalGeneration": ( + "transformers", + "TransformersMultiModalForCausalLM", + ), } _TRANSFORMERS_BACKEND_MODELS = { + # Text generation models "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), - "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 - "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501 - "TransformersMoEForMultimodalLM": ( - "transformers_moe", - "TransformersMoEForMultimodalLM", + "TransformersMoEForCausalLM": ("transformers", "TransformersMoEForCausalLM"), + # Multimodal models + "TransformersMultiModalForCausalLM": ( + "transformers", + "TransformersMultiModalForCausalLM", + ), + "TransformersMultiModalMoEForCausalLM": ( + "transformers", + "TransformersMultiModalMoEForCausalLM", ), - "TransformersEmbeddingModel": ( - "transformers_pooling", - "TransformersEmbeddingModel", + # Embedding models + "TransformersEmbeddingModel": ("transformers", "TransformersEmbeddingModel"), + "TransformersMoEEmbeddingModel": ("transformers", "TransformersMoEEmbeddingModel"), + 
"TransformersMultiModalEmbeddingModel": ( + "transformers", + "TransformersMultiModalEmbeddingModel", ), + # Sequence classification models "TransformersForSequenceClassification": ( - "transformers_pooling", + "transformers", "TransformersForSequenceClassification", ), "TransformersMoEForSequenceClassification": ( - "transformers_pooling", + "transformers", "TransformersMoEForSequenceClassification", ), - "TransformersMoEEmbeddingModel": ( - "transformers_pooling", - "TransformersMoEEmbeddingModel", + "TransformersMultiModalForSequenceClassification": ( + "transformers", + "TransformersMultiModalForSequenceClassification", ), } @@ -473,7 +497,6 @@ class _ModelInfo: has_noops: bool supports_transcription: bool supports_transcription_only: bool - supports_v0_only: bool @staticmethod def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": @@ -498,7 +521,6 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": supports_transcription_only=( supports_transcription(model) and model.supports_transcription_only ), - supports_v0_only=supports_v0_only(model), has_noops=has_noops(model), ) @@ -578,7 +600,7 @@ def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None: # file not changed, use cached _ModelInfo properties return _ModelInfo(**mi_dict["modelinfo"]) except Exception: - logger.exception( + logger.debug( ("Cached model info for class %s.%s error. "), self.module_name, self.class_name, @@ -649,7 +671,7 @@ def load_model_cls(self) -> type[nn.Module]: def _try_load_model_cls( model_arch: str, model: _BaseRegisteredModel, -) -> Optional[type[nn.Module]]: +) -> type[nn.Module] | None: from vllm.platforms import current_platform current_platform.verify_model_arch(model_arch) @@ -664,7 +686,7 @@ def _try_load_model_cls( def _try_inspect_model_cls( model_arch: str, model: _BaseRegisteredModel, -) -> Optional[_ModelInfo]: +) -> _ModelInfo | None: try: return model.inspect_model_cls() except Exception: @@ -683,7 +705,7 @@ def get_supported_archs(self) -> Set[str]: def register_model( self, model_arch: str, - model_cls: Union[type[nn.Module], str], + model_cls: type[nn.Module] | str, ) -> None: """ Register an external model to be used in vLLM. 
@@ -751,13 +773,13 @@ def _raise_for_unsupported(self, architectures: list[str]): f"Supported architectures: {all_supported_archs}" ) - def _try_load_model_cls(self, model_arch: str) -> Optional[type[nn.Module]]: + def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None: if model_arch not in self.models: return None return _try_load_model_cls(model_arch, self.models[model_arch]) - def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: + def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None: if model_arch not in self.models: return None @@ -767,7 +789,7 @@ def _try_resolve_transformers( self, architecture: str, model_config: ModelConfig, - ) -> Optional[str]: + ) -> str | None: if architecture in _TRANSFORMERS_BACKEND_MODELS: return architecture @@ -857,7 +879,7 @@ def _normalize_arch( def inspect_model_cls( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> tuple[_ModelInfo, str]: if isinstance(architectures, str): @@ -909,7 +931,7 @@ def inspect_model_cls( def resolve_model_cls( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> tuple[type[nn.Module], str]: if isinstance(architectures, str): @@ -963,7 +985,7 @@ def resolve_model_cls( def is_text_generation_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -971,7 +993,7 @@ def is_text_generation_model( def is_pooling_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -979,7 +1001,7 @@ def is_pooling_model( def is_cross_encoder_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -987,7 +1009,7 @@ def is_cross_encoder_model( def is_multimodal_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -995,7 +1017,7 @@ def is_multimodal_model( def is_multimodal_raw_input_only_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -1003,7 +1025,7 @@ def is_multimodal_raw_input_only_model( def is_pp_supported_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -1011,7 +1033,7 @@ def is_pp_supported_model( def model_has_inner_state( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -1019,7 +1041,7 @@ def model_has_inner_state( def is_attention_free_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -1027,7 +1049,7 @@ def is_attention_free_model( def is_hybrid_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: 
ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -1035,7 +1057,7 @@ def is_hybrid_model( def is_noops_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -1043,7 +1065,7 @@ def is_noops_model( def is_transcription_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) @@ -1051,20 +1073,12 @@ def is_transcription_model( def is_transcription_only_model( self, - architectures: Union[str, list[str]], + architectures: str | list[str], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) return model_cls.supports_transcription_only - def is_v1_compatible( - self, - architectures: Union[str, list[str]], - model_config: ModelConfig, - ) -> bool: - model_cls, _ = self.inspect_model_cls(architectures, model_config) - return not model_cls.supports_v0_only - ModelRegistry = _ModelRegistry( { diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 6408cf7937b2..cfccb904f46c 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional, Union import torch from torch import nn @@ -68,7 +67,7 @@ def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: token_type_ids = _decode_token_type_ids(input_ids) @@ -106,15 +105,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @default_pooling_type("CLS") class RobertaEmbeddingModel(BertEmbeddingModel): - """A model that uses Roberta to provide embedding functionalities. - - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ + """A model that uses Roberta to provide embedding functionalities.""" def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -124,8 +115,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: # Fix Roberta positions here outside of the CUDA graph. 
# Because we need the to extract the sequences from @@ -143,7 +134,7 @@ def forward( def _build_model( self, vllm_config: VllmConfig, prefix: str = "" - ) -> Union[BertModel, BertWithRope]: + ) -> BertModel | BertWithRope: if vllm_config.model_config.hf_config.position_embedding_type == "rotary": return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) else: @@ -213,20 +204,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.pooler = DispatchPooler( { - "encode": Pooler.for_encode(pooler_config), + "token_classify": Pooler.for_token_classify( + pooler_config=pooler_config, classifier=self.classifier + ), "classify": ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), + pooling=CLSPool(), classifier=self.classifier, act_fn="classify" ), "score": ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), + pooling=CLSPool(), classifier=self.classifier, act_fn="score" ), } ) @@ -240,11 +225,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, ) -> torch.Tensor: replace_roberta_positions( input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py index 89150677f3ce..92352febe87e 100644 --- a/vllm/model_executor/models/rvl.py +++ b/vllm/model_executor/models/rvl.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping -from typing import Optional import torch import torch.nn as nn @@ -41,7 +40,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index ca33a694a3b6..641160295afb 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -25,7 +25,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -75,7 +74,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -115,9 +114,9 @@ def __init__( head_dim: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + rope_scaling: tuple | None = None, prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: @@ -195,8 +194,8 @@ class SeedOssDecoderLayer(nn.Module): def __init__( self, 
config: SeedOssConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -243,7 +242,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -342,9 +341,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -477,9 +476,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -488,7 +487,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index ee21a03c8525..694e06f9fc81 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -4,14 +4,23 @@ within a vision language model.""" import math -from collections.abc import Iterable -from typing import Optional, Union +from collections.abc import Iterable, Mapping +from functools import cached_property +from typing import Annotated, Literal import torch from torch import nn -from transformers import SiglipVisionConfig +from transformers import ( + BatchFeature, + SiglipConfig, + SiglipProcessor, + SiglipTextConfig, + SiglipVisionConfig, +) from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ( @@ -19,20 +28,232 @@ QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptReplacement, + PromptUpdate, +) +from 
vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant +from .interfaces_base import default_pooling_type +from .utils import AutoWeightsLoader, maybe_prefix from .vision import ( VisionEncoderInfo, VisionFeatureSelectStrategy, + VisionFeatureSelectStrategyStr, + get_num_selected_vision_tokens, resolve_visual_encoder_outputs, ) +class SiglipImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + """ + + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] + + +_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = { + "MEAN": "full", + "ALL": "full", + "CLS": "class", +} + + +def _get_vision_feature_select_strategy( + pooling_type: str, +) -> VisionFeatureSelectStrategyStr: + try: + return _POOLING_TYPE_TO_STRATEGY[pooling_type] + except KeyError: + raise ValueError( + f"No feature selection strategy is defined for " + f"pooling_type: {pooling_type!r}" + ) from None + + +class SiglipProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(SiglipConfig) + + def get_vision_encoder_info(self): + return SiglipEncoderInfo(self.get_hf_config()) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(SiglipProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": 1} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + vision_encoder_info = self.get_vision_encoder_info() + + pooler_config = self.ctx.model_config.pooler_config + assert pooler_config is not None + + return get_num_selected_vision_tokens( + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + _get_vision_feature_select_strategy(pooler_config.pooling_type), + ) + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, image_height=target_height + ) + + +class SiglipDummyInputsBuilder(BaseDummyInputsBuilder[SiglipProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): + @cached_property + def image_token_id(self) -> int: + tokenizer = self.info.get_tokenizer() + dummy_token_id = 0 + + assert dummy_token_id not in tokenizer.all_special_ids + + return dummy_token_id + + def apply( + self, + prompt: 
str | list[int], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> MultiModalInputs: + if prompt and mm_data: + raise ValueError( + "Siglip accepts text-only or image-only inputs, not both! " + "Image-only inputs means passing an image with an empty text " + "prompt." + ) + + if mm_data: + # For multi-modal data, the prompt after processing should + # only contain the image token + tokenization_kwargs = { + **(tokenization_kwargs or {}), + "add_special_tokens": False, + } + + return super().apply( + prompt=prompt, + mm_data=mm_data, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> list[PromptUpdate]: + image_token_id = self.image_token_id + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, image_height=image_size.height + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=PromptIndexTargets.start(), + replacement=get_replacement, + ), + ] + + class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): def get_num_image_tokens( self, @@ -152,8 +373,9 @@ def forward( class SiglipAttention(nn.Module): def __init__( self, - config: SiglipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + config: SiglipVisionConfig | SiglipTextConfig, + quant_config: QuantizationConfig | None = None, + *, prefix: str = "", ) -> None: super().__init__() @@ -196,12 +418,29 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: """Input shape: Batch x Time x Channel""" qkv_states, _ = self.qkv_proj(hidden_states) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + needs_unsqueeze = query_states.ndim == 2 + if needs_unsqueeze: + query_states, key_states, value_states = ( + query_states.unsqueeze(0), + key_states.unsqueeze(0), + value_states.unsqueeze(0), + ) + out = self.attn(query_states, key_states, value_states) + + if needs_unsqueeze: + out, query_states, key_states, value_states = ( + out.squeeze(0), + query_states.squeeze(0), + key_states.squeeze(0), + value_states.squeeze(0), + ) + attn_output, _ = self.out_proj(out) return attn_output, None @@ -210,8 +449,8 @@ def forward( class SiglipMLP(nn.Module): def __init__( self, - config: SiglipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + config: SiglipVisionConfig | SiglipTextConfig, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -250,8 +489,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class 
SiglipEncoderLayer(nn.Module): def __init__( self, - config: SiglipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + config: SiglipVisionConfig | SiglipTextConfig, + quant_config: QuantizationConfig | None = None, + *, prefix: str = "", ) -> None: super().__init__() @@ -292,9 +532,10 @@ def forward( class SiglipEncoder(nn.Module): def __init__( self, - config: SiglipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - num_hidden_layers_override: Optional[int] = None, + config: SiglipVisionConfig | SiglipTextConfig, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + *, prefix: str = "", ) -> None: super().__init__() @@ -321,7 +562,7 @@ def forward( self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool, - ) -> Union[torch.Tensor, list[torch.Tensor]]: + ) -> torch.Tensor | list[torch.Tensor]: hidden_states_pool = [inputs_embeds] hidden_states = inputs_embeds @@ -336,13 +577,83 @@ def forward( return hidden_states +class SiglipTextTransformer(nn.Module): + def __init__( + self, + config: SiglipTextConfig, + quant_config: QuantizationConfig | None = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipTextEmbeddings(config) + + self.encoder = SiglipEncoder( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + ) + + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = nn.Linear(embed_dim, config.projection_size) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.token_embedding(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids, position_ids, inputs_embeds) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, return_all_hidden_states=False + ) + + last_hidden_state = self.final_layer_norm(last_hidden_state) + + return last_hidden_state + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + class SiglipMultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" def __init__( self, config: SiglipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -358,8 +669,9 @@ def __init__( ) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - batch_size = hidden_state.shape[0] - probe = self.probe.repeat(batch_size, 1, 1) + batch_size = hidden_state.size(0) + + probe = self.probe.expand(batch_size, -1, -1) 
hidden_state = self.attention(probe, hidden_state, hidden_state)[0] @@ -368,17 +680,19 @@ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.mlp(hidden_state) hidden_state += residual - return hidden_state[:, 0] + pooled = hidden_state[:, 0] + + return pooled.unsqueeze(1) class SiglipVisionTransformer(nn.Module): def __init__( self, config: SiglipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", ) -> None: super().__init__() @@ -421,19 +735,26 @@ def __init__( prefix=f"{prefix}.head", ) + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + def forward( self, pixel_values: torch.Tensor, *, interpolate_pos_encoding: bool = False, - select_layers: Optional[list[int]] = None, - feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + select_layers: list[int] | None = None, + feature_select_strategy: VisionFeatureSelectStrategy | None = None, ) -> torch.Tensor: hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, ) - # Produces either the last layer output or all of the hidden states, # depending on if we have select_layers or not encoder_outputs = self.encoder( @@ -441,21 +762,60 @@ def forward( return_all_hidden_states=select_layers is not None, ) - # Handle post-norm (if applicable) and stacks feature layers if needed + if self.post_layernorm is not None: + encoder_outputs = self.post_layernorm(encoder_outputs) + + if self.use_head: + encoder_outputs = self.head(encoder_outputs) + + # stacks feature layers if needed encoder_outputs = resolve_visual_encoder_outputs( encoder_outputs, - self.post_layernorm, + None, select_layers=select_layers, max_possible_layers=self.config.num_hidden_layers, feature_select_strategy=feature_select_strategy, ) - # TODO: add this back when pooled_output is used in inference. 
- # if self.use_head: - # pooled_output = self.head(encoder_outputs) - return encoder_outputs + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in SiglipVisionTransformer + if name.startswith("post_layernorm") and self.post_layernorm is None: + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class SiglipVisionModel(nn.Module): config_class = SiglipVisionConfig @@ -464,10 +824,10 @@ class SiglipVisionModel(nn.Module): def __init__( self, config: SiglipVisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", ) -> None: super().__init__() @@ -485,14 +845,18 @@ def get_input_embeddings(self) -> nn.Module: @property def dtype(self): - return self.get_input_embeddings().weight.dtype + return self.vision_model.dtype + + @property + def device(self): + return self.vision_model.device def forward( self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, - select_layers: Optional[list[int]] = None, - feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + select_layers: list[int] | None = None, + feature_select_strategy: VisionFeatureSelectStrategy | None = None, ) -> torch.Tensor: return self.vision_model( pixel_values=pixel_values, @@ -556,3 +920,214 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +# Adapted from: https://github.com/huggingface/transformers/blob/v4.54.1/src/transformers/models/siglip/modeling_siglip.py#L200 +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + + self.token_embedding = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + + self.position_embedding = VocabParallelEmbedding( + config.max_position_embeddings, config.hidden_size + ) + + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = 
inputs_embeds + position_embeddings + return embeddings + + +# Assume EOS token corresponds to CLS token in text model +@default_pooling_type("CLS") +@MULTIMODAL_REGISTRY.register_processor( + SiglipMultiModalProcessor, + info=SiglipProcessingInfo, + dummy_inputs=SiglipDummyInputsBuilder, +) +class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): + is_pooling_model = True + + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + merge_by_field_config = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: SiglipConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + if hasattr(config, "num_labels"): + config.num_labels = 0 + + text_config = config.text_config + vision_config = config.vision_config + + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = SiglipTextTransformer( + text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "text_model"), + ) + self.vision_model = SiglipVisionTransformer( + vision_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.text_projection_size = text_config.projection_size + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.pooler_config = pooler_config + + self.pooler = DispatchPooler( + { + "token_embed": Pooler.for_token_embed(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) + + self._is_text_input = True + + def get_text_features( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + last_hidden_state = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + text_features = self.text_model.head(last_hidden_state) + # Flip to extract CLS token (first token after reversal) for pooling + text_features = text_features.flip(0) + return text_features + + def get_image_features( + self, + pixel_values: torch.Tensor, + feature_select_strategy: VisionFeatureSelectStrategy | None = None, + ) -> torch.Tensor: + if feature_select_strategy is None: + feature_select_strategy = _get_vision_feature_select_strategy( + self.pooler_config.pooling_type + ) + + pooled_output = self.vision_model( + pixel_values=pixel_values, + select_layers=None, + feature_select_strategy=feature_select_strategy, + ) + + return pooled_output + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> SiglipImagePixelInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return None + + expected_h = expected_w = self.config.vision_config.image_size + return SiglipImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) + + def _process_image_inputs(self, inputs: SiglipImagePixelInputs) -> torch.Tensor: + pixel_values = inputs["data"] + + return self.get_image_features(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.text_model + + def get_input_embeddings( + self, + input_ids: 
torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + self._is_text_input = ( + multimodal_embeddings is None or len(multimodal_embeddings) == 0 + ) + + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + vision_embeddings = self._process_image_inputs(image_input) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor: + if intermediate_tensors is not None: + raise RuntimeError("PP is not supported for this model") + + # Multimodal inputs (image embeddings) + if not self._is_text_input: + return inputs_embeds + + return self.get_text_features(input_ids, positions, inputs_embeds) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_substrs=[".position_ids"], + ignore_unexpected_prefixes=["logit_scale.", "logit_bias."], + ) + + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 7cd133d9da1d..bab5c1d82ded 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -4,7 +4,6 @@ within a vision language model.""" from collections.abc import Iterable -from typing import Optional import torch from einops import rearrange, repeat @@ -82,7 +81,7 @@ def __init__(self, config: PretrainedConfig): def forward( self, pixel_values: torch.FloatTensor, - grid_thws: Optional[torch.LongTensor] = None, + grid_thws: torch.LongTensor | None = None, ) -> torch.Tensor: """ Args: @@ -206,9 +205,10 @@ class Siglip2Attention(nn.Module): def __init__( self, config: Siglip2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ): super().__init__() self.config = config @@ -249,7 +249,9 @@ def __init__( # Detect attention implementation. 
self.attn_backend = get_vit_attn_backend( - head_size=self.head_dim, dtype=torch.get_default_dtype() + head_size=self.head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, ) self.use_upstream_fa = False @@ -257,6 +259,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) @@ -275,8 +278,8 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """Input shape: Batch x Time x Channel""" seq_length, embed_dim = hidden_states.shape @@ -337,7 +340,7 @@ class Siglip2MLP(nn.Module): def __init__( self, config: Siglip2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -370,9 +373,10 @@ class Siglip2EncoderLayer(nn.Module): def __init__( self, config: Siglip2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -382,6 +386,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Siglip2MLP( @@ -432,9 +437,10 @@ class Siglip2Encoder(nn.Module): def __init__( self, config: Siglip2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ): super().__init__() self.config = config @@ -445,6 +451,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.layers.{idx}", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) for idx in range(config.num_hidden_layers) ] @@ -592,7 +599,7 @@ def forward( # for more information dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) reverse_indices = torch.argsort(window_index) @@ -616,9 +623,10 @@ class Siglip2VisionTransformer(nn.Module): def __init__( self, config: Siglip2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ): super().__init__() self.config = config @@ -630,6 +638,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.encoder", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) @@ -655,9 +664,10 @@ class Siglip2NavitModel(torch.nn.Module): def __init__( self, config: Siglip2VisionConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, + attn_backend_override: _Backend | None = None, ): super().__init__() @@ -666,6 +676,7 @@ def 
__init__( quant_config=quant_config, prefix=f"{prefix}.vision_model", use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, ) def forward( diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index f0f6917ddf91..44550ae595d1 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,7 +8,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Literal, TypeAlias import torch import torch.nn as nn @@ -96,14 +96,14 @@ class SkyworkR1VImageEmbeddingInputs(TensorSchema): type: Literal["image_embeds"] = "image_embeds" data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], + torch.Tensor | list[torch.Tensor], TensorShape("ni", "ifs", "hs"), ] -SkyworkR1VImageInputs = Union[ - SkyworkR1VImagePixelInputs, SkyworkR1VImageEmbeddingInputs -] +SkyworkR1VImageInputs: TypeAlias = ( + SkyworkR1VImagePixelInputs | SkyworkR1VImageEmbeddingInputs +) # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ @@ -284,9 +284,9 @@ def __init__( config: PretrainedConfig, tokenizer: AnyTokenizer, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, ) -> None: super().__init__() @@ -324,7 +324,7 @@ def image_token_id(self) -> int: def get_image_repl( self, feature_size: int, - num_patches: Optional[int], + num_patches: int | None, ) -> PromptUpdateDetails[str]: repl_features = IMG_CONTEXT * feature_size repl_full = IMG_START + repl_features + IMG_END @@ -334,10 +334,10 @@ def get_image_repl( def resolve_min_max_num( self, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - use_thumbnail: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + use_thumbnail: bool | None = None, ) -> tuple[int, int]: min_dynamic_patch = ( self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch @@ -362,10 +362,10 @@ def resolve_min_max_num( def resolve_target_ratios( self, *, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - use_thumbnail: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + use_thumbnail: bool | None = None, ) -> list[tuple[int, int]]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -399,9 +399,9 @@ def get_num_image_tokens( def _images_to_pixel_values_lst( self, images: list[Image.Image], - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, ) -> list[torch.Tensor]: min_num, max_num = self.resolve_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -423,12 +423,12 @@ def _images_to_pixel_values_lst( def __call__( self, - text: Optional[Union[str, list[str]]] = None, - images: 
Optional[Union[Image.Image, list[Image.Image]]] = None, - min_dynamic_patch: Optional[int] = None, - max_dynamic_patch: Optional[int] = None, - dynamic_image_size: Optional[bool] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: str | list[str] | None = None, + images: Image.Image | list[Image.Image] | None = None, + min_dynamic_patch: int | None = None, + max_dynamic_patch: int | None = None, + dynamic_image_size: bool | None = None, + return_tensors: str | TensorType | None = None, ) -> BatchFeature: if text is None: text = [] @@ -479,7 +479,7 @@ def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor: **kwargs, ) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -487,7 +487,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional[SkyworkR1VProcessor], + processor: SkyworkR1VProcessor | None, ) -> int: if processor is None: processor = self.get_hf_processor() @@ -532,7 +532,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) @@ -650,7 +650,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): merge_by_field_config = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -691,7 +691,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: prefix=maybe_prefix(prefix, "language_model"), ) - self.mlp1 = self._init_mlp1(config) + self.mlp1 = self._init_mlp1( + config, quant_config, prefix=maybe_prefix(prefix, "mlp1") + ) self.img_context_token_id = None self.visual_token_mask = None @@ -715,7 +717,7 @@ def _patch_quant_config( def _init_vision_model( self, config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, is_mono: bool, prefix: str, @@ -738,7 +740,12 @@ def _init_vision_model( else: return InternVisionPatchModel(config.vision_config) - def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: + def _init_mlp1( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig, + prefix: str = "", + ) -> nn.Module: vit_hidden_size = config.vision_config.hidden_size llm_hidden_size = config.text_config.hidden_size @@ -748,9 +755,17 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.1", ), nn.GELU(), - ReplicatedLinear(llm_hidden_size, llm_hidden_size, return_bias=False), + ReplicatedLinear( + llm_hidden_size, + llm_hidden_size, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.3", + ), ) def pixel_shuffle(self, x, scale_factor=0.5): @@ -784,7 +799,7 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[SkyworkR1VImageInputs]: + ) -> SkyworkR1VImageInputs | None: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", 
None) image_embeds = kwargs.pop("image_embeds", None) @@ -799,8 +814,11 @@ def _parse_and_validate_image_input( ) image_token_id = kwargs["image_token_id"] - assert isinstance(image_token_id, torch.Tensor) - self.img_context_token_id = image_token_id.flatten().unique().item() + if isinstance(image_token_id, torch.Tensor): + image_token_id = image_token_id.flatten().unique().item() + + assert isinstance(image_token_id, int) + self.img_context_token_id = image_token_id if pixel_values_flat is not None: return SkyworkR1VImagePixelInputs( @@ -818,7 +836,7 @@ def _parse_and_validate_image_input( def _process_image_input( self, image_input: SkyworkR1VImageInputs, - ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": return image_input["data"] @@ -864,9 +882,9 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: @@ -887,8 +905,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> IntermediateTensors: if intermediate_tensors is not None: @@ -913,7 +931,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/smolvlm.py b/vllm/model_executor/models/smolvlm.py index 1800330c8235..e8b805297d96 100644 --- a/vllm/model_executor/models/smolvlm.py +++ b/vllm/model_executor/models/smolvlm.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from transformers import SmolVLMProcessor @@ -17,9 +16,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo): def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor: return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs) - def _get_image_token( - self, processor: Optional[SmolVLMProcessor] - ) -> tuple[str, str]: + def _get_image_token(self, processor: SmolVLMProcessor | None) -> tuple[str, str]: if processor is None: processor = self.get_hf_processor() image_token = processor.image_token diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 5abcb47c6e25..f0dfce7bc7b6 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -25,7 +25,7 @@ """Inference-only Solar model compatible with HuggingFace weights.""" from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -73,7 +73,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | 
None = None, bias: bool = False, prefix: str = "", ) -> None: @@ -113,11 +113,11 @@ def __init__( num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, bias: bool = False, - cache_config: Optional[CacheConfig] = None, + cache_config: CacheConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -197,8 +197,8 @@ class SolarDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -250,7 +250,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: @@ -322,11 +322,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -501,9 +501,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: model_output = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 79ed00183344..a4e309e0aa6b 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -24,7 +24,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -63,7 +62,7 @@ class StablelmMLP(nn.Module): def __init__( self, config: StableLmConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -97,8 +96,8 @@ class StablelmAttention(nn.Module): def __init__( self, config: StableLmConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -184,8 +183,8 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, config: StableLmConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -254,9 +253,9 @@ def forward( self, 
input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -340,9 +339,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -351,7 +350,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index ec894140c3bf..d147237808c2 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -23,7 +23,6 @@ from collections.abc import Iterable from itertools import islice -from typing import Optional, Union import torch from torch import nn @@ -67,8 +66,8 @@ class Starcoder2Attention(nn.Module): def __init__( self, config: Starcoder2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -147,7 +146,7 @@ class Starcoder2MLP(nn.Module): def __init__( self, config: Starcoder2Config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -178,8 +177,8 @@ class Starcoder2DecoderLayer(nn.Module): def __init__( self, config: Starcoder2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -258,9 +257,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -347,9 +346,9 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) @@ -358,7 +357,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: logits = 
self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 2099055e641c..a2a1bfd30d8d 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from itertools import islice -from typing import Any, Optional +from typing import Any import torch from torch import nn @@ -54,7 +54,7 @@ class FusedMoEBlock(nn.Module): def __init__( self, config: ModelConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -106,7 +106,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -146,12 +146,12 @@ def __init__( num_kv_heads: int, norm_eps: float, rope_theta: int, - share_q_dim: Optional[int] = None, - rope_scaling: Optional[dict[str, Any]] = None, + share_q_dim: int | None = None, + rope_scaling: dict[str, Any] | None = None, max_position_embedding: int = 8192, head_dim: int = 256, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -229,8 +229,8 @@ class Step3TextDecoderLayer(nn.Module): def __init__( self, config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -291,7 +291,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states @@ -362,8 +362,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -436,8 +436,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, ): hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 5ec7845a122f..dbb549ba3f98 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Mapping, Sequence from itertools import product from math import ceil, sqrt -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import numpy as np import torch @@ -71,7 +71,7 @@ class Step3VLImagePixelInputs(TensorSchema): type: Literal["pixel_values"] pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] patch_pixel_values: Annotated[ - 
Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp") + torch.Tensor | None, TensorShape("bnp", 3, "hp", "wp") ] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -88,7 +88,7 @@ class Step3VLImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] -Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs] +Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingInputs ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None] @@ -409,7 +409,7 @@ def _get_image_repl_features( self, num_images: int, num_patches: int, - patch_new_line_idx: Optional[list[bool]], + patch_new_line_idx: list[bool] | None, ) -> tuple[str, list[int]]: if num_patches > 0: patch_repl, patch_repl_ids = self._get_patch_repl( @@ -438,9 +438,9 @@ def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> def __call__( self, - text: Optional[Union[str, list[str]]] = None, - images: Optional[Union[Image.Image, list[Image.Image]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: str | list[str] | None = None, + images: Image.Image | list[Image.Image] | None = None, + return_tensors: str | TensorType | None = None, ) -> BatchFeature: if text is None: text = [] @@ -513,7 +513,7 @@ def get_hf_processor(self) -> Step3VLProcessor: self.get_tokenizer(), ) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_max_image_tokens(self) -> int: @@ -556,7 +556,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) @@ -716,7 +716,7 @@ class Step3VisionAttention(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -778,7 +778,7 @@ class Step3VisionMLP(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -813,7 +813,7 @@ class Step3VisionEncoderLayer(nn.Module): def __init__( self, config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -848,7 +848,7 @@ class Step3VisionEncoder(nn.Module): def __init__( self, config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -881,7 +881,7 @@ class Step3VisionTransformer(nn.Module): def __init__( self, config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, ): @@ -927,7 +927,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) supports_encoder_tp_data = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if 
modality.startswith("image"): return "<im_patch>" @@ -994,7 +994,7 @@ def dtype(self): def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[Step3VLImageInputs]: + ) -> Step3VLImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) patch_pixel_values = kwargs.pop("patch_pixel_values", None) num_patches = kwargs.pop("num_patches", None) @@ -1085,9 +1085,9 @@ def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings: def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, # Multi-modal token ID may exceed vocab size handle_oov_mm_token: bool = True, ) -> torch.Tensor: @@ -1106,10 +1106,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: @@ -1130,7 +1130,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py index 485c008e830a..a74fd80c06d8 100644 --- a/vllm/model_executor/models/swin.py +++ b/vllm/model_executor/models/swin.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn @@ -28,7 +27,7 @@ def __init__( dim: int, num_heads: int, window_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -102,9 +101,9 @@ def _get_rel_pos_bias(self) -> torch.Tensor: def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + output_attentions: bool | None = False, ) -> tuple[torch.Tensor, ...]: batch_size, dim, num_channels = hidden_states.shape @@ -155,7 +154,7 @@ def __init__( self, config: SwinConfig, dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -181,7 +180,7 @@ def __init__( dim: int, num_heads: int, window_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -201,9 +200,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, + attention_mask: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + output_attentions: bool | None = False, ) -> 
tuple[torch.Tensor]: self_outputs = self.self( hidden_states, attention_mask, head_mask, output_attentions @@ -218,7 +217,7 @@ def __init__( self, config: SwinConfig, dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -241,7 +240,7 @@ def __init__( self, config: SwinConfig, dim: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -266,7 +265,7 @@ def __init__( num_heads: int, drop_path_rate: float = 0.0, shift_size: int = 0, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__( @@ -303,8 +302,8 @@ def __init__( depth: int, num_heads: int, drop_path: list[float], - downsample: Optional[SwinPatchMerging] = None, - quant_config: Optional[QuantizationConfig] = None, + downsample: SwinPatchMerging | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -340,9 +339,9 @@ def forward( self, hidden_states: torch.Tensor, input_dimensions: tuple[int, int], - head_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - always_partition: Optional[bool] = False, + head_mask: torch.FloatTensor | None = None, + output_attentions: bool | None = False, + always_partition: bool | None = False, ) -> tuple[torch.Tensor]: height, width = input_dimensions for i, layer_module in enumerate(self.blocks): @@ -384,7 +383,7 @@ def __init__( self, config: SwinConfig, grid_size: int, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -426,9 +425,9 @@ def forward( self, hidden_states: torch.Tensor, input_dimensions: tuple[int, int], - head_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - always_partition: Optional[bool] = False, + head_mask: torch.FloatTensor | None = None, + output_attentions: bool | None = False, + always_partition: bool | None = False, ) -> tuple[torch.Tensor]: for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None @@ -455,7 +454,7 @@ class SwinModel(nn.Module): def __init__( self, config: SwinConfig, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -473,9 +472,9 @@ def __init__( def forward( self, - pixel_values: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, + pixel_values: torch.FloatTensor | None = None, + head_mask: torch.FloatTensor | None = None, + output_attentions: bool | None = None, ) -> tuple[torch.Tensor]: embedding_output, input_dimensions = self.embeddings(pixel_values) diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 6a224fe9288b..bfa1b5bbaf84 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union +from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar import torch import torch.nn as nn @@ -81,7 +81,7 @@ class 
TarsierImageEmbeddingInputs(TensorSchema): data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -TarsierImageInputs = Union[TarsierImagePixelInputs, TarsierImageEmbeddingInputs] +TarsierImageInputs: TypeAlias = TarsierImagePixelInputs | TarsierImageEmbeddingInputs class TarsierHfConfig(Protocol): # Based on the Tarsier's LlavaConfig @@ -89,7 +89,7 @@ class TarsierHfConfig(Protocol): # Based on the Tarsier's LlavaConfig text_config: Final[PretrainedConfig] # Added from Tarsier's LlavaConfig image_token_index: Final[int] vision_feature_select_strategy: Final[str] - vision_feature_layer: Final[Union[int, list[int]]] + vision_feature_layer: Final[int | list[int]] projector_hidden_act: Final[str] image_newline_idx: Final[int] image_new_idx: Final[int] @@ -109,9 +109,10 @@ class TarsierProcessor(LlavaProcessor): def __call__( self, images: ImageInput = None, - text: Union[ - TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] - ] = None, + text: TextInput + | PreTokenizedInput + | list[TextInput] + | list[PreTokenizedInput] = None, audio=None, videos=None, **kwargs: Unpack[TarsierProcessorKwargs], @@ -173,7 +174,7 @@ def __init__( text_hidden_size: int, projector_hidden_act: str, multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -215,7 +216,7 @@ def get_hf_processor(self, **kwargs: object) -> TarsierProcessor: return self.ctx.get_hf_processor(TarsierProcessor, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} def get_num_image_tokens( @@ -331,7 +332,7 @@ def _build_tarsier_hf_processor( info: _I_Tarsier, dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> BaseMultiModalProcessor: if isinstance(info, TarsierProcessingInfo): return TarsierMultiModalProcessor( @@ -344,11 +345,11 @@ def _build_tarsier_hf_processor( def init_vision_tower_for_tarsier( hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol - quant_config: Optional[QuantizationConfig], + quant_config: QuantizationConfig | None, *, - require_post_norm: Optional[bool] = None, + require_post_norm: bool | None = None, prefix: str = "", -) -> Union[CLIPVisionModel, SiglipVisionModel]: +) -> CLIPVisionModel | SiglipVisionModel: vision_config = hf_config.vision_config feature_layers = hf_config.vision_feature_layer @@ -407,7 +408,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) } @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return "<image>" @@ -456,7 +457,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def _parse_and_validate_image_input( self, **kwargs: object - ) -> Optional[TarsierImageInputs]: + ) -> TarsierImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -479,9 +480,9 @@ def _parse_and_validate_image_input( def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel], - pixel_values: Union[torch.Tensor, list[torch.Tensor]], - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + vision_tower: CLIPVisionModel | SiglipVisionModel, + 
pixel_values: torch.Tensor | list[torch.Tensor], + ) -> torch.Tensor | tuple[torch.Tensor, ...]: # From vLLM LLaVA, vision tower output handling return vision_tower( pixel_values, @@ -540,7 +541,7 @@ def _add_tarsier_split_tokens( def _process_image_pixels( self, inputs: TarsierImagePixelInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] image_features_selected = self._image_pixels_to_features( @@ -559,7 +560,7 @@ def _process_image_pixels( def _process_image_input( self, image_input: TarsierImageInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": projected_features = image_input["data"] if isinstance(projected_features, torch.Tensor): @@ -585,10 +586,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: @@ -610,7 +611,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 13d2e8eacc01..e799e41e2c38 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -18,8 +18,8 @@ """Wrapper around `Terratorch` models""" from collections import OrderedDict -from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Callable, Optional, Union +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Any import torch import torch.nn as nn @@ -34,7 +34,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY @@ -96,7 +96,7 @@ def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]): class TerratorchProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None} @@ -114,7 +114,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: # Dummy data is generated based on the 'input' section # defined in the HF configuration file @@ -136,8 +136,8 @@ def __init__(self, pretrained_cfg: dict, *args, **kwargs): def _parse_image_data( self, - data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], - ) -> Optional[ModalityDataItems[Any, Any]]: + 
data: dict[str, torch.Tensor] | ModalityData[ImageItem], + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): terratorch_fields = _terratorch_field_names(self._pretrained_cfg) @@ -157,7 +157,7 @@ def __init__( info: TerratorchProcessingInfo, dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]", *, - cache: Optional[MultiModalProcessorOnlyCache] = None, + cache: MultiModalProcessorOnlyCache | None = None, ) -> None: self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"] super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache) @@ -182,11 +182,11 @@ def _get_prompt_updates( def apply( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_uuids: Optional[MultiModalUUIDDict] = None, + tokenization_kwargs: Mapping[str, object] | None = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: if "image" in mm_data: image_data = mm_data["image"] @@ -232,7 +232,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): is_pooling_model = True @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return None @@ -249,16 +249,14 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, - ) + self.pooler = DispatchPooler({"plugin": DummyPooler()}) def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: # We do not really use any input tokens and therefore no embeddings @@ -269,10 +267,10 @@ def get_input_embeddings( def forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor | None, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, ): model_output = self.inference_runner.forward(**kwargs) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py deleted file mode 100644 index 47e829861284..000000000000 --- a/vllm/model_executor/models/transformers.py +++ /dev/null @@ -1,948 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2024 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
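The tarsier.py and terratorch.py hunks above all apply one convention: `typing.Optional`/`typing.Union` annotations become PEP 604 unions, and module-level aliases are marked with `TypeAlias`. A minimal sketch of that convention follows; the names are illustrative only, not the vLLM classes, and the runtime `|` syntax assumes Python 3.10+ (annotation-only uses can also rely on `from __future__ import annotations`).

```python
from typing import TypeAlias

import torch

# Old style (what these hunks remove):
#   ImageInputs = Union[PixelInputs, EmbeddingInputs]
#   def select(x: Optional[torch.Tensor] = None) -> Union[torch.Tensor, None]: ...

# New style: PEP 604 unions plus an explicit TypeAlias annotation.
PixelInputs: TypeAlias = torch.Tensor
EmbeddingInputs: TypeAlias = list[torch.Tensor]
ImageInputs: TypeAlias = PixelInputs | EmbeddingInputs


def select(x: torch.Tensor | None = None) -> torch.Tensor | None:
    """`X | None` replaces Optional[X] with no typing import needed."""
    return x
```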
-"""Wrapper around `transformers` models""" - -from collections.abc import Iterable, Mapping -from contextlib import contextmanager -from pathlib import Path -from typing import Literal, Optional, Union - -import regex as re -import torch -import transformers -from packaging.version import Version -from torch import nn -from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - -from vllm.attention import Attention, AttentionType -from vllm.compilation.decorators import support_torch_compile -from vllm.config import ( - CacheConfig, - DeviceConfig, - ModelConfig, - ParallelConfig, - VllmConfig, -) -from vllm.config.multimodal import BaseDummyOptions -from vllm.config.utils import getattr_iter -from vllm.distributed import get_pp_group, get_tp_group -from vllm.distributed.utils import get_pp_indices -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - MultiModalInputs, - MultiModalUUIDDict, - PlaceholderRange, -) -from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems -from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant -from .utils import ( - AutoWeightsLoader, - PPMissingLayer, - WeightsMapper, - make_empty_intermediate_tensors_factory, - maybe_prefix, -) - -logger = init_logger(__name__) - - -def get_feature_request_tip( - model: str, - trust_remote_code: bool, -) -> str: - hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new" - gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose" - url = hf_url if trust_remote_code else gh_url - prefix = f"Please open {url} to request support for this feature. " - if Path(model).exists(): - prefix = "" - doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models" - tip = f"See {doc_url} for instructions on how to add support yourself." 
- return f"{prefix}{tip}" - - -def vllm_flash_attention_forward( - # Transformers args - module: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor, - # Transformers kwargs - scaling: Optional[float] = None, - # vLLM kwargs - attention_instances: Optional[dict[Attention]] = None, - **kwargs, -): - self_attn = attention_instances[module.layer_idx] - if scaling is not None: - self_attn.impl.scale = float(scaling) - hidden = query.shape[-2] - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) - query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) - return self_attn.forward(query, key, value), None - - -ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward - - -def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): - logger.debug("%s: %s -> %s", name, old_module, new_module) - - -def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: - """ - Callable to be passed to `@support_torch_compile`'s `enable_if` argument. - - Defaults to `True` but is disabled in the following situations: - - - The model uses dynamic rope scaling. - """ - enable = True - text_config = vllm_config.model_config.hf_config.get_text_config() - # Dynamic rope scaling is not compatible with torch.compile - rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {} - if rope_scaling.get("rope_type") == "dynamic": - enable = False - return enable - - -Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] - - -def replace_linear_class( - linear: nn.Linear, - style: Style = "replicate", - quant_config: Optional[QuantizationConfig] = None, - *, - prefix: str = "", -) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: - """ - Replace nn.Linear with one of vLLM's tensor parallel linear classes. - - Args: - linear: `nn.Linear` to be replaced. - style: Tensor parallel style of the new linear, e.g. "colwise". - quant_config: Quantization config for the new linear. - Returns: - The new linear. - """ - - if not isinstance(style, str): - raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") - - vllm_linear_cls, vllm_linear_kwargs = { - "colwise": (ColumnParallelLinear, {}), - "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), - "rowwise": (RowParallelLinear, {}), - "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), - "replicate": (ReplicatedLinear, {}), - }.get(style, (ReplicatedLinear, {})) - - return vllm_linear_cls( - input_size=linear.in_features, - output_size=linear.out_features, - bias=linear.bias is not None, - quant_config=quant_config, - prefix=prefix, - return_bias=False, - **vllm_linear_kwargs, - ) - - -def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: - """Replace a Transformers RMSNorm with vLLM's RMSNorm. - - This method assumes: - - Weight is stored as `weight`. - - Epsilon is stored as `eps` or `variance_epsilon`. - - `with_scale` indicates whether the layer has a weight (Gemma3n only). - - `var_hidden_size` is only ever used for Intern vision encoder in vLLM - and Transformers doesn't appear to have the same concept. 
- """ - kwargs = { - "hidden_size": hidden_size, - "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6), - "has_weight": getattr(rms_norm, "with_scale", True), - } - if (weight := getattr(rms_norm, "weight", None)) is not None: - # If weight is a Parameter, get its data tensor - weight = getattr(weight, "data", weight) - kwargs["dtype"] = weight.dtype - else: - # No weight, fall back to weightless RMSNorm - kwargs["has_weight"] = False - return RMSNorm(**kwargs) - - -# Copied from `accelerate` -@contextmanager -def init_on_device_without_buffers(device: torch.device): - """ - A context manager under which models are initialized with all - parameters on the specified device. However buffers are not - initialized on specified device. - - Args: - device (`torch.device`): - Device to initialize all parameters on. - """ - - old_register_parameter = nn.Module.register_parameter - - def register_empty_parameter(module, name, param): - old_register_parameter(module, name, param) - if param is not None: - param_cls = type(module._parameters[name]) - kwargs = module._parameters[name].__dict__ - kwargs["requires_grad"] = param.requires_grad - module._parameters[name] = param_cls( - module._parameters[name].to(device), **kwargs - ) - - tensor_constructors_to_patch = {} - - def patch_tensor_constructor(fn): - def wrapper(*args, **kwargs): - kwargs["device"] = device - return fn(*args, **kwargs) - - return wrapper - - try: - nn.Module.register_parameter = register_empty_parameter - for torch_function_name in tensor_constructors_to_patch: - setattr( - torch, - torch_function_name, - patch_tensor_constructor(getattr(torch, torch_function_name)), - ) - yield - finally: - nn.Module.register_parameter = old_register_parameter - for ( - torch_function_name, - old_torch_function, - ) in tensor_constructors_to_patch.items(): - setattr(torch, torch_function_name, old_torch_function) - - -class MultiModalProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self): - return {"image": None} - - def get_mm_max_tokens_per_item(self, seq_len, mm_counts): - return {"image": self.get_max_image_tokens()} - - def get_max_image_tokens(self) -> int: - width, height = self.get_max_image_size() - processor = self.get_hf_processor() - multimodal_config = self.ctx.model_config.multimodal_config - mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} - mm_tokens = processor._get_num_multimodal_tokens( - image_sizes=([height, width],), **mm_processor_kwargs - ) - image_tokens = mm_tokens["num_image_tokens"][0] - return image_tokens - - def get_max_image_size(self): - return 10_000, 10_000 # hardcode for arbitrary very large size - - -class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_images = mm_counts.get("image", 0) - - processor = self.info.get_hf_processor() - if "gemma3" in processor.__class__.__name__.lower(): - image_token = processor.boi_token - else: - image_token = getattr(processor, "image_token", "") - return image_token * num_images - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width, target_height = self.info.get_max_image_size() - - image_overrides = mm_options.get("image") if mm_options else None - - return { - "image": self._get_dummy_images( - width=target_width, - height=target_height, - 
num_images=num_images, - overrides=image_overrides, - ), - } - - -class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ): - """ - Given the original multi-modal items for this modality - and HF-processed data, output the updates to perform. - - The information returned by this method is used to update token inputs - which bypass the HF processor. It is also used to update the output of - HF processor if the HF process does not apply prompt updates to text - inputs. - - Moreover, this information is critical to determine the token positions - in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` - for each multi-modal item. - """ - return None - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - # HF Processors always return a mask but vLLM doesn't need it - hf_inputs.pop("attention_mask", None) - num_image_patches = hf_inputs.get("num_image_patches") - mm_fields = { - key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches) - for key in hf_inputs - } - mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( - "image", num_image_patches - ) - - # Keep these as batched, as they always have batch size as first dim - mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image") - mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image") - mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") - return mm_fields - - def _get_hf_mm_data( - self, - mm_items: MultiModalDataItems, - ) -> tuple[Mapping[str, object], Mapping[str, object]]: - """ - In contrast to the base class, this method always adds - `return_mm_token_type_ids` to the processor data - """ - processor_data, passthrough_data = super()._get_hf_mm_data(mm_items) - processor_data["return_mm_token_type_ids"] = True - return processor_data, passthrough_data - - def apply( - self, - prompt: Union[str, list[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, - mm_uuids: Optional[MultiModalUUIDDict] = None, - ) -> MultiModalInputs: - """ - Process multi-modal inputs to be used in vLLM. - - Apply HF Processor on prompt text and multi-modal data together, - outputting token IDs and processed tensors. - """ - if tokenization_kwargs is None: - tokenization_kwargs = {} - - mm_items = self._to_mm_items(mm_data) - hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - if not isinstance(prompt, str): - # the prompt is the tokenized ids which is not supported - # by the hf_processor, which is why we would need to decode the ids - # into string - prompt = hf_processor.decode(prompt) - - # Bypass cached processor and always apply to the full set of mm inputs - # NOTE: we can't just set caching=False because base class method - # transforms outputs to `MultiModalKwargs` which is not going to - # work for Transformers. 
We have a lot of logic tied to - # `mm_tokens_per_modality` below - prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm( - prompt_text=prompt, - mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - tokenization_kwargs=tokenization_kwargs, - ) - - # For gemma3 we check `token_type_ids` as the key - token_type_key = ( - "mm_token_type_ids" - if "mm_token_type_ids" in processed_data - else "token_type_ids" - ) - mm_token_type_ids = processed_data.pop(token_type_key) - - # We can infer vLLM style placeholder from token type ids, if we split - # it for each input `mm_data`. - mm_positions = torch.where(mm_token_type_ids == 1)[1] - images = mm_items.get_items("image", ImageProcessorItems) - multimodal_config = self.info.ctx.model_config.multimodal_config - mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} - image_sizes = [] - for item_idx in range(len(images)): - image_size = images.get_image_size(item_idx) - image_sizes.append((image_size.height, image_size.width)) - - mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( - image_sizes=image_sizes, **mm_processor_kwargs - ) - - mm_placeholders = {} - split_sizes = mm_tokens_per_modality["num_image_tokens"] - if split_sizes: - chunked_mm_positions = torch.split(mm_positions, split_sizes) - mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()] - chunked_mm_tokens = torch.split(mm_tokens, split_sizes) - ranges = [ - PlaceholderRange( - offset=positions[0].item(), - length=positions.shape[0], - is_embed=(mm_tokens == hf_processor.image_token_id).bool(), - ) - for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens) - ] - mm_placeholders = {"image": ranges} - - processed_data["num_image_patches"] = torch.tensor( - mm_tokens_per_modality["num_image_patches"] - ) - mm_kwargs = MultiModalKwargsItems.from_hf_inputs( - processed_data, - self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), - ) - - # Use overrides if provided; fallback to data-dependent hashing. 
- mm_hashes = self._hash_mm_items( - mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids - ) - - return MultiModalInputs( - type="multimodal", - prompt_token_ids=prompt_ids, - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, - mm_placeholders=mm_placeholders, - ) - - -class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): - embedding_padding_modules = ["lm_head"] - embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - logger.info("Using Transformers backend.") - - self.config: PretrainedConfig = vllm_config.model_config.hf_config - self.text_config: PretrainedConfig = self.config.get_text_config() - self.cache_config: CacheConfig = vllm_config.cache_config - self.device_config: DeviceConfig = vllm_config.device_config - self.model_config: ModelConfig = vllm_config.model_config - self.parallel_config: ParallelConfig = vllm_config.parallel_config - self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config - - self.pp_group = get_pp_group() - self.tp_group = get_tp_group() - - # Weights to skip in `self.load_weights` - self.skip_prefixes: list[str] = [] - """Skip loading weights whose qualname starts with these prefixes.""" - self.skip_substrs: list[str] = [] - """Skip loading weights whose qualname contains these substrings.""" - self.ignore_unexpected_prefixes: list[str] = [] - """Ignore unexpected weights whose qualname starts with these prefixes. - """ - self.ignore_unexpected_suffixes: list[str] = [] - """Ignore unexpected weights whose qualname ends with these suffixes.""" - - if self.quant_config: - quant_method_name = self.quant_config.get_name() - # Check for unsupported quantization methods. - if quant_method_name == "mxfp4": - raise NotImplementedError( - "Transformers backend does not support MXFP4 quantization yet." - ) - # Skip loading extra bias for GPTQ models. 
- if "gptq" in quant_method_name: - self.ignore_unexpected_suffixes.append(".bias") - - # Set correct attn and init on "meta" to delay allocating GPU tensors - self.text_config._attn_implementation = "vllm" - with init_on_device_without_buffers("meta"): - self.model: PreTrainedModel = AutoModel.from_config( - self.config, - torch_dtype=self.model_config.dtype, - trust_remote_code=self.model_config.trust_remote_code, - ) - - # Remove layers not on this pipeline parallel rank - self.pipeline_parallel() - # Substitute remaining layers with vLLM's layers as needed - self.recursive_replace() - # Create attention instances for KV cache allocation - self.attention_instances = self.create_attention_instances() - - # Input embeddings - input_embeddings = self.model.get_input_embeddings() - if not isinstance(input_embeddings, PPMissingLayer): - # Some models use embedding scales - self.embed_scale = getattr(input_embeddings, "embed_scale", None) - names = ("embedding_size", "hidden_size") - embedding_dim = getattr_iter(self.text_config, names, None) - assert embedding_dim is not None - self.model.set_input_embeddings( - VocabParallelEmbedding( - self.text_config.vocab_size, - embedding_dim=embedding_dim, - org_num_embeddings=self.text_config.vocab_size, - quant_config=self.quant_config, - ) - ) - - # Initialize any parameters that have not had their modules replaced - self.init_parameters(self.model) - - # Pipeline parallel intermediate tensors - self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states"], self.text_config.hidden_size - ) - - def pipeline_parallel(self): - """ - Apply the model's pipeline parallelization plan. - """ - if self.pp_group.world_size <= 1: - return - - if not self.model.supports_pp_plan: - tip = get_feature_request_tip( - self.model_config.model, self.model_config.trust_remote_code - ) - raise ValueError( - f"{type(self.model)} does not support pipeline parallel. {tip}" - ) - - module_lists = [] - module_list_idx = None - pp_plan = list(self.model._pp_plan.keys()) - for i, name in enumerate(pp_plan): - if isinstance(getattr(self.model, name), nn.ModuleList): - module_lists.append(name) - module_list_idx = i - - if len(module_lists) > 1: - raise ValueError( - "Pipeline parallel of models with multiple `ModuleList`s " - "in the base model are not supported yet!" - ) - if module_list_idx is None: - raise ValueError(f"Could not find `ModuleList` in {type(self.model)}") - - # Layers before module list - for name in pp_plan[:module_list_idx]: - if self.pp_group.is_first_rank or ( - self.text_config.tie_word_embeddings and self.pp_group.is_last_rank - ): - continue - setattr(self.model, name, PPMissingLayer()) - - # Module list - start_layer, end_layer = get_pp_indices( - self.text_config.num_hidden_layers, - self.pp_group.rank_in_group, - self.pp_group.world_size, - ) - layers_name = pp_plan[module_list_idx] - layers = getattr(self.model, layers_name) - for i in range(len(layers)): - if start_layer <= i and i < end_layer: - continue - layers[i] = PPMissingLayer() - - # Layers after module list - for name in pp_plan[module_list_idx + 1 :]: - # Modules that should be on last rank - if not self.pp_group.is_last_rank: - setattr(self.model, name, PPMissingLayer()) - - def recursive_replace(self): - """Recursively replace modules in the model as needed. 
- - Currently, this replaces: - - - `nn.Linear` with vLLM's tensor parallel linear classes - - `*RMSNorm` with vLLM's `RMSNorm` - """ - tp_plan = self.model.tp_plan - - if not tp_plan and self.tp_group.world_size > 1: - tip = get_feature_request_tip( - self.model_config.model, self.model_config.trust_remote_code - ) - raise ValueError( - f"{type(self.model)} does not support tensor parallel. {tip}" - ) - - # Prefix the patterns because we always start from `self.model` - tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} - - def _recursive_replace(module: nn.Module, prefix: str): - for child_name, child_module in module.named_children(): - new_module = child_module - qual_name = maybe_prefix(prefix, child_name) - if isinstance(child_module, nn.Linear): - generator = (p for p in tp_plan if re.match(p, qual_name)) - pattern = next(generator, None) - # Some weight loaders expect all linear layers to inherit - # LinearBase, so we set a default style which causes any - # unspecified layers to be replaced with ReplicatedLinear - style = tp_plan.get(pattern, "replicate") - new_module = replace_linear_class( - child_module, style, self.quant_config, prefix=qual_name - ) - # TODO(hmellor): Enable RMSNorm replacement once we have a way - # to choose RMSNorm vs GemmaRMSNorm - # elif child_module.__class__.__name__.endswith("RMSNorm"): - # new_module = replace_rms_norm_class( - # child_module, self.config.hidden_size) - else: - _recursive_replace(child_module, prefix=qual_name) - - if new_module is not child_module: - setattr(module, child_name, new_module) - log_replacement(qual_name, child_module, new_module) - - _recursive_replace(self.model, prefix="model") - - def create_attention_instances( - self, attn_type: AttentionType = AttentionType.DECODER - ) -> dict[int, Attention]: - """ - Create `Attention` instances to inform KV cache allocation. - """ - num_heads = self.model_config.get_num_attention_heads(self.parallel_config) - head_size = self.model_config.get_head_size() - num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None) - start, end = get_pp_indices( - self.text_config.num_hidden_layers, - self.pp_group.rank_in_group, - self.pp_group.world_size, - ) - - attention_instances = {} - for i in range(start, end): - # Handle interleaved sliding window attention - per_layer_sliding_window = None - if ( - hasattr(self.config, "layer_types") - and self.config.layer_types[i] == "sliding_attention" - ): - per_layer_sliding_window = self.config.sliding_window - - attention_instances[i] = Attention( - num_heads=num_heads, - head_size=head_size, - # NOTE: We use Llama scale as default, if it's set by - # Transformers, it's updated in vllm_flash_attention_forward - scale=head_size**-0.5, - num_kv_heads=num_kv_heads, - cache_config=self.cache_config, - quant_config=self.quant_config, - logits_soft_cap=logits_soft_cap, - per_layer_sliding_window=per_layer_sliding_window, - prefix=f"{i}.attn", - attn_type=attn_type, - ) - return attention_instances - - def init_parameters(self, module: nn.Module, dtype: Optional[torch.dtype] = None): - """ - If a `parameter` is on the `meta` device, then its parent - `module` is the original module created by: - - ```python - with torch.device("meta"): - self.model: PreTrainedModel = AutoModel.from_config(...) 
- ``` - """ - - def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]): - for name, param in module.named_parameters(recurse=False): - if param.device == torch.device("meta"): - new_param = nn.Parameter( - torch.empty_like( - param.data, - dtype=dtype or self.model_config.dtype, - device=self.device_config.device, - ) - ) - setattr(module, name, new_param) - for child in module.children(): - _init_parameters(child, dtype) - - _init_parameters(module, dtype) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - if not self.pp_group.is_first_rank: - assert intermediate_tensors is not None - input_ids = None - inputs_embeds = intermediate_tensors["hidden_states"] - - if input_ids is not None: - input_ids = input_ids[None, ...] - if inputs_embeds is not None: - inputs_embeds = inputs_embeds[None, ...] - - if self.model_config.uses_mrope: - position_ids = positions[:, None] - else: - position_ids = positions[None, ...] - - hidden_states = self.model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - use_cache=False, - position_ids=position_ids, - attention_instances=self.attention_instances, - return_dict=False, - **kwargs, - )[0][0, ...] # we remove batch dimension for now - - if not self.pp_group.is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) - - return hidden_states - - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]], - ) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=self.skip_prefixes, - skip_substrs=self.skip_substrs, - ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, - ignore_unexpected_suffixes=self.ignore_unexpected_suffixes, - ) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def check_version(self, min_version: str, feature: str): - installed = Version(transformers.__version__) - required = Version(min_version) - if installed < required: - raise ImportError( - f"Transformers backend requires transformers>={required} " - f"for {feature}, but got {installed}" - ) - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersForCausalLM(TransformersBase): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - # Tell `TransformersBase.load_weights` to skip - # `lm_head` if the model has tied word embeddings - if self.text_config.tie_word_embeddings: - self.skip_prefixes.append("lm_head.") - - if self.pp_group.is_last_rank: - self.unpadded_vocab_size = self.text_config.vocab_size - self.lm_head = ParallelLMHead( - self.text_config.vocab_size, - self.text_config.hidden_size, - quant_config=self.quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - if self.text_config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.get_input_embeddings() - ) - - logit_scale = getattr(self.text_config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale - ) - else: - self.lm_head = PPMissingLayer() - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings()(input_ids) - if self.embed_scale is not None: - inputs_embeds *= self.embed_scale - return inputs_embeds - - def compute_logits( - self, - 
hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits - - -@MULTIMODAL_REGISTRY.register_processor( - MultiModalProcessor, - info=MultiModalProcessingInfo, - dummy_inputs=MultiModalDummyInputsBuilder, -) -@support_torch_compile( - # set `positions` to last dim to support Qwen-mrope - dynamic_arg_dims={ - "input_ids": 0, - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }, - enable_if=can_enable_torch_compile, -) -class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): - supports_multimodal_raw_input_only = True - merge_by_field_config = True - # Backwards compatibility for prev released models. State dicts back then - # had different formats and cannot be loaded with `AutoModel` mapping as is - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - "language_model.model": "model.language_model", - "text_model.model": "model.text_model", - "vision_tower": "model.vision_tower", - "vqmodel": "model.vqmodel", - "visual": "model.visual", - "vision_model": "model.vision_model", - "vision_embed_tokens": "model.vision_embed_tokens", - "image_newline": "model.image_newline", - "multi_modal_projector": "model.multi_modal_projector", - "text_model.lm_head": "lm_head", - "language_model.lm_head": "lm_head", - # Qwen models used "model" as the name for the language model. - # Therefore, we must map each of submodule explicitly to avoid - # conflicts with newer models that use "model.language_model". - "model.embed_tokens": "model.language_model.embed_tokens", - "model.layers": "model.language_model.layers", - "model.norm": "model.language_model.norm", - } - ) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - self.dtype = vllm_config.model_config.dtype - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: - # Gemma3 and PaliGemma needs `token_type_ids` to work correctly - # Other models will not have `token_type_ids` in kwargs - kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"} - model_output = super().forward( - input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs - ) - return model_output - - def get_language_model(self) -> torch.nn.Module: - """`TransformersForMultimodalLM` does not contain a vLLM language model class. 
- Therefore, in order to return a language model vLLM class, we use a wrapper to - give `self` the same interface as `TransformersForCausalLM`.""" - - class LanguageModelWrapper(TransformersForCausalLM): - def __init__(self, multimodal_model): - # Don't call super().__init__() to avoid re-initialization - self.__dict__.update(multimodal_model.__dict__) - - model = getattr_iter(self.model, ("language_model", "text_model"), None) - - return LanguageModelWrapper(self) - - def get_multimodal_embeddings(self, **kwargs): - pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None) - image_embeds: Optional[torch.Tensor] = kwargs.pop("image_embeds", None) - # Model might use `image_patches` instead of `pixel_values` - if pixel_values is None: - pixel_values = kwargs.pop("image_patches", None) - - if image_embeds is not None: - return image_embeds - - if pixel_values is None: - return None - - num_image_patches = kwargs.pop("num_image_patches") - kwargs.pop("token_type_ids", None) # used only in `forward` - if pixel_values is not None: - vision_embeddings = self.model.get_image_features(pixel_values, **kwargs) - - if isinstance(vision_embeddings, torch.Tensor): - if vision_embeddings.ndim == 2: - vision_embeddings = vision_embeddings.unsqueeze(0) - - # Embeddings have to be 2D tensors of length `num_images` - # but transformers returns concat tensors if each patch - # is of different size. We split it back to make vLLM happy - vision_embeddings = torch.split( - vision_embeddings, num_image_patches.flatten().tolist() - ) - vision_embeddings = [ - embed.flatten(start_dim=0, end_dim=-2) - for embed in vision_embeddings - ] - - return vision_embeddings - - get_input_embeddings = SupportsMultiModal.get_input_embeddings diff --git a/vllm/model_executor/models/transformers/__init__.py b/vllm/model_executor/models/transformers/__init__.py new file mode 100644 index 000000000000..365b5eb08893 --- /dev/null +++ b/vllm/model_executor/models/transformers/__init__.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Wrapper around `transformers` models""" + +from vllm.compilation.decorators import support_torch_compile +from vllm.model_executor.models.transformers.base import Base +from vllm.model_executor.models.transformers.causal import CausalMixin +from vllm.model_executor.models.transformers.legacy import LegacyMixin +from vllm.model_executor.models.transformers.moe import MoEMixin +from vllm.model_executor.models.transformers.multimodal import ( + DYNAMIC_ARG_DIMS, + MultiModalDummyInputsBuilder, + MultiModalMixin, + MultiModalProcessingInfo, + MultiModalProcessor, +) +from vllm.model_executor.models.transformers.pooling import ( + EmbeddingMixin, + SequenceClassificationMixin, +) +from vllm.model_executor.models.transformers.utils import can_enable_torch_compile +from vllm.multimodal import MULTIMODAL_REGISTRY + + +# Text only models +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersForCausalLM(CausalMixin, Base): ... + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ... + + +# Multimodal models +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder, +) +@support_torch_compile( + dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile +) +class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ... + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder, +) +@support_torch_compile( + dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile +) +class TransformersMultiModalMoEForCausalLM( + MoEMixin, MultiModalMixin, CausalMixin, Base +): ... + + +# Embedding models +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ... + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ... + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder, +) +@support_torch_compile( + dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile +) +class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ... + + +# Sequence classification models +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersForSequenceClassification( + SequenceClassificationMixin, LegacyMixin, Base +): ... + + +@support_torch_compile(enable_if=can_enable_torch_compile) +class TransformersMoEForSequenceClassification( + SequenceClassificationMixin, MoEMixin, Base +): ... + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder, +) +@support_torch_compile( + dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile +) +class TransformersMultiModalForSequenceClassification( + SequenceClassificationMixin, MultiModalMixin, Base +): ... + + +def __getattr__(name: str): + """Handle imports of non-existent classes with a helpful error message.""" + if name not in globals(): + raise AttributeError( + "The Transformers backend does not currently have a class to handle " + f"the requested model type: {name}. 
Please open an issue at " + "https://github.com/vllm-project/vllm/issues/new" + ) + return globals()[name] diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py new file mode 100644 index 000000000000..41d170c9e139 --- /dev/null +++ b/vllm/model_executor/models/transformers/base.py @@ -0,0 +1,457 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers backend base class.""" + +from collections.abc import Iterable +from typing import TYPE_CHECKING + +import regex as re +import torch +import transformers +from packaging.version import Version +from torch import nn +from transformers import AutoModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +from vllm.attention import Attention, AttentionType +from vllm.config.utils import getattr_iter +from vllm.distributed import get_pp_group, get_tp_group +from vllm.distributed.utils import get_pp_indices +from vllm.logger import init_logger +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.interfaces import ( + SupportsLoRA, + SupportsPP, + SupportsQuant, +) +from vllm.model_executor.models.interfaces_base import VllmModel +from vllm.model_executor.models.transformers.utils import ( + get_feature_request_tip, + init_on_device_without_buffers, + log_replacement, + replace_linear_class, + replace_rms_norm_class, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from vllm.config import VllmConfig +else: + PreTrainedModel = object + +logger = init_logger(__name__) + + +def vllm_flash_attention_forward( + # Transformers args + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + # Transformers kwargs + scaling: float | None = None, + # vLLM kwargs + attention_instances: dict[int, Attention] | None = None, + **kwargs, +): + self_attn = attention_instances[module.layer_idx] + if scaling is not None: + self_attn.impl.scale = float(scaling) + hidden = query.shape[-2] + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) + return self_attn.forward(query, key, value), None + + +ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward + + +class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): + embedding_padding_modules = ["lm_head"] + embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Add `model.` prefix for base model checkpoints, + # handling the case 
where it is already present + "": "model.", + "model.model.": "model.", + # Heads will be adjacent to `model` (pooling included because of adapters) + "model.lm_head.": "lm_head.", + "model.score.": "classifier.", + "model.classifier.": "classifier.", + } + ) + + def __init_subclass__(cls, *args, **kwargs): + """Merge hf_to_vllm_mapper in MRO from most specific to least specific.""" + super().__init_subclass__(*args, **kwargs) + hf_to_vllm_mapper = WeightsMapper() + for base in cls.__mro__: + if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None): + hf_to_vllm_mapper |= base_hf_to_vllm_mapper + cls.hf_to_vllm_mapper = hf_to_vllm_mapper + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + super().__init__() + logger.info("Using Transformers backend.") + + self.config = vllm_config.model_config.hf_config + self.text_config = self.config.get_text_config() + self.cache_config = vllm_config.cache_config + self.device_config = vllm_config.device_config + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.quant_config = vllm_config.quant_config + + self.pp_group = get_pp_group() + self.tp_group = get_tp_group() + + # Weights to skip in `self.load_weights` + self.skip_prefixes: list[str] = [] + """Skip loading weights whose qualname starts with these prefixes.""" + self.skip_substrs: list[str] = [] + """Skip loading weights whose qualname contains these substrings.""" + self.ignore_unexpected_prefixes: list[str] = [] + """Ignore unexpected weights whose qualname starts with these prefixes. + """ + self.ignore_unexpected_suffixes: list[str] = [] + """Ignore unexpected weights whose qualname ends with these suffixes.""" + + if self.quant_config: + quant_method_name = self.quant_config.get_name() + # Check for unsupported quantization methods. + if quant_method_name == "mxfp4": + raise NotImplementedError( + "Transformers backend does not support MXFP4 quantization yet." + ) + # Skip loading extra bias for GPTQ models. 
+ if "gptq" in quant_method_name: + self.ignore_unexpected_suffixes.append(".bias") + + # Set correct attn and init on "meta" to delay allocating GPU tensors + self.text_config._attn_implementation = "vllm" + with init_on_device_without_buffers("meta"): + self.model: PreTrainedModel = AutoModel.from_config( + self.config, + dtype=self.model_config.dtype, + trust_remote_code=self.model_config.trust_remote_code, + ) + + # Remove layers not on this pipeline parallel rank + self.pipeline_parallel() + # Substitute remaining layers with vLLM's layers as needed + self.recursive_replace() + # Create attention instances for KV cache allocation + self.attention_instances = self.create_attention_instances() + + # Input embeddings + input_embeddings = self.model.get_input_embeddings() + if not isinstance(input_embeddings, PPMissingLayer): + # Some models scale embeddings inside the input embedding layer + self.embed_scale = getattr(input_embeddings, "embed_scale", None) + names = ("embedding_size", "hidden_size") + embedding_dim = getattr_iter(self.text_config, names, None) + assert embedding_dim is not None + self.model.set_input_embeddings( + VocabParallelEmbedding( + self.text_config.vocab_size, + embedding_dim=embedding_dim, + org_num_embeddings=self.text_config.vocab_size, + quant_config=self.quant_config, + ) + ) + + # Initialize any parameters that have not had their modules replaced + self.init_parameters(self.model) + + # Pipeline parallel intermediate tensors + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.text_config.hidden_size + ) + + def pipeline_parallel(self): + """ + Apply the model's pipeline parallelization plan. + """ + if self.pp_group.world_size <= 1: + return + + if not self.model.supports_pp_plan: + tip = get_feature_request_tip( + self.model_config.model, self.model_config.trust_remote_code + ) + raise ValueError( + f"{type(self.model)} does not support pipeline parallel. {tip}" + ) + + module_lists = [] + module_list_idx = None + pp_plan = list(self.model._pp_plan.keys()) + for i, name in enumerate(pp_plan): + if isinstance(getattr(self.model, name), nn.ModuleList): + module_lists.append(name) + module_list_idx = i + + if len(module_lists) > 1: + raise ValueError( + "Pipeline parallel of models with multiple `ModuleList`s " + "in the base model are not supported yet!" + ) + if module_list_idx is None: + raise ValueError(f"Could not find `ModuleList` in {type(self.model)}") + + # Layers before module list + for name in pp_plan[:module_list_idx]: + if self.pp_group.is_first_rank or ( + self.text_config.tie_word_embeddings and self.pp_group.is_last_rank + ): + continue + setattr(self.model, name, PPMissingLayer()) + + # Module list + start_layer, end_layer = get_pp_indices( + self.text_config.num_hidden_layers, + self.pp_group.rank_in_group, + self.pp_group.world_size, + ) + layers_name = pp_plan[module_list_idx] + layers = getattr(self.model, layers_name) + for i in range(len(layers)): + if start_layer <= i and i < end_layer: + continue + layers[i] = PPMissingLayer() + + # Layers after module list + for name in pp_plan[module_list_idx + 1 :]: + # Modules that should be on last rank + if not self.pp_group.is_last_rank: + setattr(self.model, name, PPMissingLayer()) + + def recursive_replace(self): + """Recursively replace modules in the model as needed. 
+ + Currently, this replaces: + + - `nn.Linear` with vLLM's tensor parallel linear classes + - `*RMSNorm` with vLLM's `RMSNorm` + """ + tp_plan = self.model.tp_plan + + if not tp_plan and self.tp_group.world_size > 1: + tip = get_feature_request_tip( + self.model_config.model, self.model_config.trust_remote_code + ) + raise ValueError( + f"{type(self.model)} does not support tensor parallel. {tip}" + ) + + # Prefix the patterns because we always start from `self.model` + tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} + + def _recursive_replace(module: nn.Module, prefix: str): + for child_name, child_module in module.named_children(): + new_module = child_module + qual_name = maybe_prefix(prefix, child_name) + if isinstance(child_module, nn.Linear): + generator = (p for p in tp_plan if re.match(p, qual_name)) + pattern = next(generator, None) + # Some weight loaders expect all linear layers to inherit + # LinearBase, so we set a default style which causes any + # unspecified layers to be replaced with ReplicatedLinear + style = tp_plan.get(pattern, "replicate") + new_module = replace_linear_class( + child_module, style, self.quant_config, prefix=qual_name + ) + elif child_module.__class__.__name__.endswith("RMSNorm"): + new_module = replace_rms_norm_class( + child_module, self.text_config.hidden_size + ) + else: + _recursive_replace(child_module, prefix=qual_name) + + if new_module is not child_module: + setattr(module, child_name, new_module) + log_replacement(qual_name, child_module, new_module) + + _recursive_replace(self.model, prefix="model") + + def create_attention_instances(self) -> dict[int, Attention]: + """ + Create `Attention` instances to inform KV cache allocation. + """ + text_config = self.text_config + + num_heads = self.model_config.get_num_attention_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + logits_soft_cap = getattr(text_config, "attn_logit_softcapping", None) + + # In encoder models, the attention layers will have `is_causal=False` + is_encoder = lambda module: not getattr(module, "is_causal", True) + has_encoder = lambda model: any(is_encoder(m) for m in model.modules()) + is_multimodal = lambda config: config != config.get_text_config() + # vLLM does not support encoder-decoder models, so if any encoder layer is + # found in a text only model, we assume the whole model is an encoder model + if has_encoder(self.model) and not is_multimodal(self.config): + self.check_version("4.57.0.dev0", "encoder models support") + attn_type = AttentionType.ENCODER_ONLY + else: + attn_type = AttentionType.DECODER + + pp_rank = self.pp_group.rank_in_group + pp_size = self.pp_group.world_size + start, end = get_pp_indices(text_config.num_hidden_layers, pp_rank, pp_size) + + attention_instances = {} + for i in range(start, end): + # Handle interleaved sliding window attention + per_layer_sliding_window = None + if ( + hasattr(self.config, "layer_types") + and self.config.layer_types[i] == "sliding_attention" + ): + per_layer_sliding_window = self.config.sliding_window + + attention_instances[i] = Attention( + num_heads=num_heads, + head_size=head_size, + # NOTE: We use Llama scale as default, if it's set by + # Transformers, it's updated in vllm_flash_attention_forward + scale=head_size**-0.5, + num_kv_heads=num_kv_heads, + cache_config=self.cache_config, + quant_config=self.quant_config, + logits_soft_cap=logits_soft_cap, + 
per_layer_sliding_window=per_layer_sliding_window, + prefix=f"{i}.attn", + attn_type=attn_type, + ) + return attention_instances + + def init_parameters(self, module: nn.Module, dtype: torch.dtype | None = None): + """ + If a `parameter` is on the `meta` device, then its parent + `module` is the original module created by: + + ```python + with torch.device("meta"): + self.model: "PreTrainedModel" = AutoModel.from_config(...) + ``` + """ + + def _init_parameters(module: nn.Module, dtype: torch.dtype | None): + for name, param in module.named_parameters(recurse=False): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like( + param.data, + dtype=dtype or self.model_config.dtype, + device=self.device_config.device, + ) + ) + setattr(module, name, new_param) + for child in module.children(): + _init_parameters(child, dtype) + + _init_parameters(module, dtype) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings()(input_ids) + if self.embed_scale is not None: + inputs_embeds *= self.embed_scale + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + if not self.pp_group.is_first_rank: + assert intermediate_tensors is not None + input_ids = None + inputs_embeds = intermediate_tensors["hidden_states"] + + if input_ids is not None: + input_ids = input_ids[None, ...] + if inputs_embeds is not None: + inputs_embeds = inputs_embeds[None, ...] + + # If the model scales embeddings inside the input embedding layer we must + # ensure they are scaled here since VocabParallelEmbedding will not do it + if ( + self.embed_scale is not None + and input_ids is not None + and inputs_embeds is None + ): + inputs_embeds = self.get_input_embeddings(input_ids) + input_ids = None + + if self.model_config.uses_mrope: + position_ids = positions[:, None] + else: + position_ids = positions[None, ...] + + hidden_states = self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + use_cache=False, + position_ids=position_ids, + attention_instances=self.attention_instances, + return_dict=False, + **kwargs, + )[0][0, ...] 
# we remove batch dimension for now + + if not self.pp_group.is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + return hidden_states + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=self.skip_prefixes, + skip_substrs=self.skip_substrs, + ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, + ignore_unexpected_suffixes=self.ignore_unexpected_suffixes, + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + @staticmethod + def check_version(min_version: str, feature: str): + installed = Version(transformers.__version__) + required = Version(min_version) + if installed < required: + raise ImportError( + f"Transformers backend requires transformers>={required} " + f"for {feature}, but got {installed}" + ) diff --git a/vllm/model_executor/models/transformers/causal.py b/vllm/model_executor/models/transformers/causal.py new file mode 100644 index 000000000000..7f7b15a5675a --- /dev/null +++ b/vllm/model_executor/models/transformers/causal.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
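The `__init__.py` above assembles every public class by stacking mixins on top of `Base` (for example `TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base)`), and `Base.__init_subclass__` merges each class's `hf_to_vllm_mapper` across that MRO. The sketch below shows the same merge mechanism with a plain dict in place of `WeightsMapper`; the class and key names are illustrative, and the "most specific entry wins" rule follows the docstring rather than `WeightsMapper`'s exact operator semantics.

```python
class Base:
    # Least specific mapping; mixins contribute additional entries.
    name_map: dict[str, str] = {"": "model."}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        merged: dict[str, str] = {}
        # Walk the MRO from most specific to least specific and keep the
        # first value seen for each key.
        for klass in cls.__mro__:
            for key, value in vars(klass).get("name_map", {}).items():
                merged.setdefault(key, value)
        cls.name_map = merged


class CausalMixin:
    name_map = {"model.lm_head.": "lm_head."}


class LegacyMixin:
    name_map = {".gamma": ".weight", ".beta": ".bias"}


class ForCausalLM(CausalMixin, Base): ...


class LegacyForCausalLM(LegacyMixin, CausalMixin, Base): ...


if __name__ == "__main__":
    print(ForCausalLM.name_map)
    # {'model.lm_head.': 'lm_head.', '': 'model.'}
    print(LegacyForCausalLM.name_map)
    # {'.gamma': '.weight', '.beta': '.bias', 'model.lm_head.': 'lm_head.', '': 'model.'}
```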
+"""Transformers backend mixin for causal language models.""" + +from typing import TYPE_CHECKING + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.interfaces_base import VllmModelForTextGeneration +from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix + +if TYPE_CHECKING: + import torch + + from vllm.config import VllmConfig + + +class CausalMixin(VllmModelForTextGeneration): + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + # Skip VllmModelForTextGeneration.__init__ and call the next class in MRO + super(VllmModelForTextGeneration, self).__init__( + vllm_config=vllm_config, prefix=prefix + ) + + # Tell `Base.load_weights` to skip + # `lm_head` if the model has tied word embeddings + if self.text_config.tie_word_embeddings: + self.skip_prefixes.append("lm_head.") + + if self.pp_group.is_last_rank: + self.unpadded_vocab_size = self.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.text_config.vocab_size, + self.text_config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if self.text_config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.get_input_embeddings() + ) + + logit_scale = getattr(self.text_config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale + ) + else: + self.lm_head = PPMissingLayer() + + def compute_logits(self, hidden_states: "torch.Tensor") -> "torch.Tensor | None": + logits = self.logits_processor(self.lm_head, hidden_states) + return logits diff --git a/vllm/model_executor/models/transformers/legacy.py b/vllm/model_executor/models/transformers/legacy.py new file mode 100644 index 000000000000..a453870a2687 --- /dev/null +++ b/vllm/model_executor/models/transformers/legacy.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers backend mixin for legacy models.""" + +from typing import TYPE_CHECKING + +import torch + +from vllm.model_executor.models.utils import WeightsMapper +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class LegacyMixin: + hf_to_vllm_mapper = WeightsMapper( + # These are applied in order, so the order matters! 
+ orig_to_new_prefix={ + # Handle BERT-like models + "roberta": "model", + "bert": "model", + }, + orig_to_new_suffix={ + # Replace legacy suffixes used for norms + ".gamma": ".weight", + ".beta": ".bias", + }, + ) + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # Skip unsupported/unwanted output embeddings layers + self.skip_prefixes.extend( + [ + "model.lm_head.", + "model.predictions.", + "model.qa_outputs.", + "model.embeddings_project.", + "model.discriminator_predictions.", + ] + ) + + # Some encoder models have the position_ids buffer in the checkpoint. + # vLLM will always pass position_ids as an argument, so we skip loading + # the buffer if it exists + self.skip_substrs.append("position_ids") + + # Some encoder models have the bias of the final classifier layer + # in the checkpoint. vLLM does not use this bias, so we skip loading + # it if it exists + self.skip_substrs.append("score.bias") + + # roberta-like models an extra padding in positions. + # FIXME(Isotr0py): This is quite hacky for roberta edge case, + # we should find a better way to handle this. + self.is_roberta = "roberta" in self.text_config.model_type + self.padding_idx = self.text_config.pad_token_id + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if self.is_roberta: + # RoBERTa-specific positions padding + positions += self.padding_idx + 1 + return super().forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers/moe.py similarity index 90% rename from vllm/model_executor/models/transformers_moe.py rename to vllm/model_executor/models/transformers/moe.py index 5267e447902f..5de786f99580 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers/moe.py @@ -14,31 +14,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
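For the RoBERTa special case handled in `LegacyMixin.forward` above: RoBERTa checkpoints reserve position ids up to `pad_token_id`, so the 0-based positions vLLM passes are shifted by `padding_idx + 1`. A toy illustration (the pad id of 1 matches `roberta-base`, used here only as an example):

```python
import torch

padding_idx = 1                         # roberta-base pad_token_id
positions = torch.arange(4)             # positions as vLLM passes them
shifted = positions + padding_idx + 1   # what the wrapped model actually sees
assert shifted.tolist() == [2, 3, 4, 5]
```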
-"""Wrapper around `transformers` MoE models.""" +"""Transformers backend mixin for Mixture of Experts (MoE) models.""" -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn -from vllm.compilation.decorators import support_torch_compile from vllm.config.utils import getattr_iter from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.models.interfaces import MixtureOfExperts +from vllm.model_executor.models.utils import maybe_prefix from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op - -from .interfaces import MixtureOfExperts, SupportsMultiModal -from .transformers import ( - TransformersBase, - TransformersForCausalLM, - TransformersForMultimodalLM, - can_enable_torch_compile, - log_replacement, -) -from .utils import maybe_prefix +from vllm.utils.torch_utils import direct_register_custom_op + +from .utils import log_replacement + +if TYPE_CHECKING: + from vllm.config import VllmConfig @CustomOp.register("transformers_fused_moe") @@ -117,11 +113,11 @@ def transformers_moe_forward_fake( ) -class TransformersMoEBase(TransformersBase, MixtureOfExperts): - def __init__(self, *, vllm_config, prefix=""): +class MoEMixin(MixtureOfExperts): + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): self.check_version("4.57.0.dev0", "MoE models support") - self.ep_group = get_ep_group() - super().__init__(vllm_config=vllm_config, prefix=prefix) + # Skip MixtureOfExperts.__init__ and call the next class in MRO + super(MixtureOfExperts, self).__init__(vllm_config=vllm_config, prefix=prefix) def set_eplb_state( self, @@ -242,7 +238,7 @@ def forward(self, *args, **kwargs): num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts # MixtureOfExperts mixin settings - ep_size = self.ep_group.world_size + ep_size = get_ep_group().world_size self.mlp_layers = [] # Used for MixtureOfExperts methods self.expert_weights = [] @@ -316,24 +312,5 @@ def _recursive_replace(module: nn.Module, prefix: str): _recursive_replace(child_module, prefix=qual_name) _recursive_replace(self.model, prefix="model") - # Continue with the replacement of layers in TransformersBase + # Continue with the replacement of layers in Base super().recursive_replace() - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): - pass - - -@support_torch_compile( - # set `positions` to last dim to support Qwen-mrope - dynamic_arg_dims={ - "input_ids": 0, - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }, - enable_if=can_enable_torch_compile, -) -class TransformersMoEForMultimodalLM(TransformersMoEBase, TransformersForMultimodalLM): - get_input_embeddings = SupportsMultiModal.get_input_embeddings diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py new file mode 100644 index 000000000000..10abd8659536 --- /dev/null +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers backend mixin for multi-modal models.""" + +from collections.abc import Mapping +from typing import TYPE_CHECKING + +import torch + +from vllm.config.utils import getattr_iter +from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal +from vllm.model_executor.models.utils import WeightsMapper +from vllm.multimodal import MultiModalKwargsItems +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalUUIDDict, + PlaceholderRange, +) +from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems +from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from transformers import BatchFeature, PretrainedConfig + + from vllm.config import VllmConfig + from vllm.config.multimodal import BaseDummyOptions + +DYNAMIC_ARG_DIMS = { + "input_ids": 0, + # set `positions` to last dim to support Qwen-mrope + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, +} + + +class MultiModalProcessingInfo(BaseProcessingInfo): + def get_supported_mm_limits(self): + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len, mm_counts): + return {"image": self.get_max_image_tokens()} + + def get_max_image_tokens(self) -> int: + width, height = self.get_max_image_size() + processor = self.get_hf_processor() + multimodal_config = self.ctx.model_config.multimodal_config + mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} + mm_tokens = processor._get_num_multimodal_tokens( + image_sizes=([height, width],), **mm_processor_kwargs + ) + image_tokens = mm_tokens["num_image_tokens"][0] + return image_tokens + + def get_max_image_size(self): + return 10_000, 10_000 # hardcode for arbitrary very large size + + +class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + if "gemma3" in processor.__class__.__name__.lower(): + image_token = processor.boi_token + else: + image_token = getattr(processor, "image_token", "") + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, "BaseDummyOptions"] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_max_image_size() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + } + + +class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, 
+ hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ): + """ + Given the original multi-modal items for this modality + and HF-processed data, output the updates to perform. + + The information returned by this method is used to update token inputs + which bypass the HF processor. It is also used to update the output of + HF processor if the HF process does not apply prompt updates to text + inputs. + + Moreover, this information is critical to determine the token positions + in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` + for each multi-modal item. + """ + return None + + def _get_mm_fields_config( + self, + hf_inputs: "BatchFeature", + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + # HF Processors always return a mask but vLLM doesn't need it + hf_inputs.pop("attention_mask", None) + num_image_patches = hf_inputs.get("num_image_patches") + mm_fields = { + key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches) + for key in hf_inputs + } + mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( + "image", num_image_patches + ) + + # Keep these as batched, as they always have batch size as first dim + mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image") + mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image") + mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") + return mm_fields + + def _get_hf_mm_data( + self, + mm_items: MultiModalDataItems, + ) -> tuple[Mapping[str, object], Mapping[str, object]]: + """ + In contrast to the base class, this method always adds + `return_mm_token_type_ids` to the processor data + """ + processor_data, passthrough_data = super()._get_hf_mm_data(mm_items) + processor_data["return_mm_token_type_ids"] = True + return processor_data, passthrough_data + + def apply( + self, + prompt: str | list[int], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object] | None = None, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> MultiModalInputs: + """ + Process multi-modal inputs to be used in vLLM. + + Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + """ + if tokenization_kwargs is None: + tokenization_kwargs = {} + + mm_items = self._to_mm_items(mm_data) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + if not isinstance(prompt, str): + # the prompt is the tokenized ids which is not supported + # by the hf_processor, which is why we would need to decode the ids + # into string + prompt = hf_processor.decode(prompt) + + # Bypass cached processor and always apply to the full set of mm inputs + # NOTE: we can't just set caching=False because base class method + # transforms outputs to `MultiModalKwargs` which is not going to + # work for Transformers. 
We have a lot of logic tied to + # `mm_tokens_per_modality` below + prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + + # For gemma3 we check `token_type_ids` as the key + token_type_key = ( + "mm_token_type_ids" + if "mm_token_type_ids" in processed_data + else "token_type_ids" + ) + mm_token_type_ids = processed_data.pop(token_type_key) + + # We can infer vLLM style placeholder from token type ids, if we split + # it for each input `mm_data`. + mm_positions = torch.where(mm_token_type_ids == 1)[1] + images = mm_items.get_items("image", ImageProcessorItems) + multimodal_config = self.info.ctx.model_config.multimodal_config + mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} + image_sizes = [] + for item_idx in range(len(images)): + image_size = images.get_image_size(item_idx) + image_sizes.append((image_size.height, image_size.width)) + + mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( + image_sizes=image_sizes, **mm_processor_kwargs + ) + + mm_placeholders = {} + split_sizes = mm_tokens_per_modality["num_image_tokens"] + if split_sizes: + chunked_mm_positions = torch.split(mm_positions, split_sizes) + mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()] + chunked_mm_tokens = torch.split(mm_tokens, split_sizes) + ranges = [ + PlaceholderRange( + offset=positions[0].item(), + length=positions.shape[0], + is_embed=(mm_tokens == hf_processor.image_token_id).bool(), + ) + for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens) + ] + mm_placeholders = {"image": ranges} + + processed_data["num_image_patches"] = torch.tensor( + mm_tokens_per_modality["num_image_patches"] + ) + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), + ) + + # Use overrides if provided; fallback to data-dependent hashing. + mm_hashes = self._hash_mm_items( + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids + ) + + return MultiModalInputs( + type="multimodal", + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholders, + ) + + +class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): + supports_multimodal_raw_input_only = True + merge_by_field_config = True + # Backwards compatibility for prev released models. State dicts back then + # had different formats and cannot be loaded with `AutoModel` mapping as is + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model": "model.language_model", + "text_model.model": "model.text_model", + "vision_tower": "model.vision_tower", + "vqmodel": "model.vqmodel", + "visual": "model.visual", + "vision_model": "model.vision_model", + "vision_embed_tokens": "model.vision_embed_tokens", + "image_newline": "model.image_newline", + "multi_modal_projector": "model.multi_modal_projector", + "text_model.lm_head": "lm_head", + "language_model.lm_head": "lm_head", + # Qwen models used "model" as the name for the language model. + # Therefore, we must map each of submodule explicitly to avoid + # conflicts with newer models that use "model.language_model". 
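A small, self-contained version of the placeholder bookkeeping done in `apply()` above: positions where the token type id is 1 are image tokens, and splitting them by the per-image token counts yields the offset/length pairs that back each `PlaceholderRange`. The sizes below are made up for illustration, not taken from a real processor:

```python
import torch

mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0]])
split_sizes = [3, 2]                                   # image tokens per image
mm_positions = torch.where(mm_token_type_ids == 1)[1]  # tensor([2, 3, 4, 7, 8])
chunked = torch.split(mm_positions, split_sizes)
ranges = [(chunk[0].item(), chunk.shape[0]) for chunk in chunked]
assert ranges == [(2, 3), (7, 2)]                      # (offset, length) per image
```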
+ "model.embed_tokens": "model.language_model.embed_tokens", + "model.layers": "model.language_model.layers", + "model.norm": "model.language_model.norm", + } + ) + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + # Skip SupportsMRoPE.__init__ and call the next class in MRO + super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + # Gemma3 and PaliGemma needs `token_type_ids` to work correctly + # Other models will not have `token_type_ids` in kwargs + kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"} + model_output = super().forward( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) + return model_output + + def get_language_model(self) -> torch.nn.Module: + """Transformers backend multimodal classes do not contain a separate vLLM + language model class. Therefore, in order to return a language model vLLM class, + we use a wrapper to give `self` the same interface as a text model.""" + + # Exclude self and object + bases = self.__class__.mro()[1:-1] + # Keep only classes defined in `vllm.model_executor.models.transformers` + bases = [b for b in bases if ".transformers." in b.__module__] + # Exclude MultiModalMixin itself + bases = [b for b in bases if b is not MultiModalMixin] + + class LanguageModel(*bases): + def __init__(self, multimodal_model): + # Don't call super().__init__() to avoid re-initialization + self.__dict__.update(multimodal_model.__dict__) + + model = getattr_iter(self.model, ("language_model", "text_model"), None) + + return LanguageModel(self) + + def get_multimodal_embeddings(self, **kwargs): + pixel_values: torch.Tensor | None = kwargs.pop("pixel_values", None) + image_embeds: torch.Tensor | None = kwargs.pop("image_embeds", None) + # Model might use `image_patches` instead of `pixel_values` + if pixel_values is None: + pixel_values = kwargs.pop("image_patches", None) + + if image_embeds is not None: + return image_embeds + + if pixel_values is None: + return None + + num_image_patches = kwargs.pop("num_image_patches") + kwargs.pop("token_type_ids", None) # used only in `forward` + if pixel_values is not None: + vision_embeddings = self.model.get_image_features(pixel_values, **kwargs) + + if isinstance(vision_embeddings, torch.Tensor): + if vision_embeddings.ndim == 2: + vision_embeddings = vision_embeddings.unsqueeze(0) + + # Embeddings have to be 2D tensors of length `num_images` + # but transformers returns concat tensors if each patch + # is of different size. 
We split it back to make vLLM happy + vision_embeddings = torch.split( + vision_embeddings, num_image_patches.flatten().tolist() + ) + vision_embeddings = [ + embed.flatten(start_dim=0, end_dim=-2) + for embed in vision_embeddings + ] + + return vision_embeddings + + def get_mrope_input_positions( + self, + input_tokens: list[int], + hf_config: "PretrainedConfig", + image_grid_thw: list[list[int]] | torch.Tensor | None, + video_grid_thw: list[list[int]] | torch.Tensor | None, + second_per_grid_ts: list[float] | None = None, + context_len: int = 0, + seq_len: int | None = None, + audio_feature_lengths: torch.Tensor | None = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + if any((second_per_grid_ts, audio_feature_lengths, use_audio_in_video)): + raise NotImplementedError("Transformers backend only supports images.") + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + mrope_positions, mrope_position_delta = self.model.get_rope_index( + input_ids=torch.tensor(input_tokens).unsqueeze(0), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + ) + + mrope_positions = mrope_positions[:, 0, context_len:seq_len] + mrope_position_delta = mrope_position_delta[0].item() + + return mrope_positions, mrope_position_delta diff --git a/vllm/model_executor/models/transformers/pooling.py b/vllm/model_executor/models/transformers/pooling.py new file mode 100644 index 000000000000..8117bbac013e --- /dev/null +++ b/vllm/model_executor/models/transformers/pooling.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
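The split performed just above, in isolation: Transformers may return all image patches concatenated into one 2D tensor, and `num_image_patches` is what lets vLLM cut it back into one embedding tensor per image. Toy shapes only:

```python
import torch

vision_embeddings = torch.randn(5, 64)        # 2 + 3 patches, concatenated
num_image_patches = torch.tensor([2, 3])
per_image = torch.split(vision_embeddings, num_image_patches.flatten().tolist())
assert [tuple(e.shape) for e in per_image] == [(2, 64), (3, 64)]
```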
+"""Transformers backend mixins for pooling models.""" + +from typing import TYPE_CHECKING + +import torch +from transformers import AutoModelForSequenceClassification + +from vllm.config.utils import getattr_iter +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + CLSPool, + DispatchPooler, + Pooler, +) +from vllm.model_executor.models.interfaces import SupportsCrossEncoding +from vllm.model_executor.models.interfaces_base import VllmModelForPooling + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class EmbeddingMixin(VllmModelForPooling): + default_pooling_type = "CLS" + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + # Skip VllmModelForPooling.__init__ and call the next class in MRO + super(VllmModelForPooling, self).__init__( + vllm_config=vllm_config, prefix=prefix + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "token_embed": Pooler.for_token_embed(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) + + +class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): + default_pooling_type = "CLS" + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""): + # Skip VllmModelForPooling.__init__ and call the next class in MRO + super(VllmModelForPooling, self).__init__( + vllm_config=vllm_config, prefix=prefix + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # Certain information about the the model and classifier can only be + # inferred from the `ForSequenceClassification` class. Therefore, we + # instantiate it on the "meta" device to avoid allocating GPU memory. + with torch.device("meta"): + seq_cls_model = AutoModelForSequenceClassification.from_config( + self.config, + dtype=self.model_config.dtype, + trust_remote_code=self.model_config.trust_remote_code, + ) + + # When used for sequence classification, some models have their + # pooling layers removed. Make sure this is reflected in vLLM. + for module in seq_cls_model.modules(): + if hasattr(module, "pooler") and module.pooler is None: + self.model.pooler = None + break + + # Unlike `lm_head`, `classifier` is not always `nn.Linear`. + self.classifier = getattr_iter(seq_cls_model, ["classifier", "score"], None) + if self.classifier is None: + raise ValueError( + "Could not find `classifier` or `score` layer in the " + "`AutoModelForSequenceClassification` instance." + ) + self.init_parameters(self.classifier, dtype=self.model_config.head_dtype) + + class ClassifierWithReshape(self.classifier.__class__): + """CLSPool has already been applied in `pooling`. 
+ Add dim to match expected input shape of `classifier.forward`.""" + + def forward(self, *args, **kwargs): + if len(args) > 0: + args = (args[0].unsqueeze(1), *args[1:]) + return super().forward(*args, **kwargs) + + self.classifier.__class__ = ClassifierWithReshape + + self.pooler = DispatchPooler( + { + "token_classify": Pooler.for_token_classify( + pooler_config, classifier=self.classifier + ), + "classify": ClassifierPooler( + pooling=CLSPool(), classifier=self.classifier, act_fn="classify" + ), + "score": ClassifierPooler( + pooling=CLSPool(), classifier=self.classifier, act_fn="score" + ), + } + ) diff --git a/vllm/model_executor/models/transformers/utils.py b/vllm/model_executor/models/transformers/utils.py new file mode 100644 index 000000000000..267a6e06e6bb --- /dev/null +++ b/vllm/model_executor/models/transformers/utils.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformers backend utilities.""" + +from contextlib import contextmanager +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import torch +from torch import nn + +from vllm.config.utils import getattr_iter +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.model_executor.layers.quantization import QuantizationConfig + + +logger = init_logger(__name__) + + +# Copied from `accelerate` +@contextmanager +def init_on_device_without_buffers(device: torch.device): + """ + A context manager under which models are initialized with all + parameters on the specified device. However buffers are not + initialized on specified device. + + Args: + device (`torch.device`): + Device to initialize all parameters on. 
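The `ClassifierWithReshape` trick above swaps an instance's `__class__` to wrap `forward()` without touching the already-loaded weights. The same pattern on a plain `nn.Linear`, purely for illustration:

```python
import torch
from torch import nn

classifier = nn.Linear(8, 2)

class LinearWithReshape(classifier.__class__):
    def forward(self, *args, **kwargs):
        if len(args) > 0:
            # Re-insert the sequence dim that CLS pooling removed.
            args = (args[0].unsqueeze(1), *args[1:])
        return super().forward(*args, **kwargs)

classifier.__class__ = LinearWithReshape
out = classifier(torch.randn(4, 8))   # (batch, hidden) in
assert out.shape == (4, 1, 2)         # (batch, 1, num_labels) out
```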
+ """ + + old_register_parameter = nn.Module.register_parameter + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs + ) + + tensor_constructors_to_patch = {} + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + try: + nn.Module.register_parameter = register_empty_parameter + for torch_function_name in tensor_constructors_to_patch: + setattr( + torch, + torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name)), + ) + yield + finally: + nn.Module.register_parameter = old_register_parameter + for ( + torch_function_name, + old_torch_function, + ) in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) + + +Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] + + +def replace_linear_class( + linear: nn.Linear, + style: Style = "replicate", + quant_config: "QuantizationConfig | None" = None, + *, + prefix: str = "", +) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear: + """ + Replace nn.Linear with one of vLLM's tensor parallel linear classes. + + Args: + linear: `nn.Linear` to be replaced. + style: Tensor parallel style of the new linear, e.g. "colwise". + quant_config: Quantization config for the new linear. + Returns: + The new linear. + """ + + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + + vllm_linear_cls, vllm_linear_kwargs = { + "colwise": (ColumnParallelLinear, {}), + "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), + "rowwise": (RowParallelLinear, {}), + "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), + "replicate": (ReplicatedLinear, {}), + }.get(style, (ReplicatedLinear, {})) + + return vllm_linear_cls( + input_size=linear.in_features, + output_size=linear.out_features, + bias=linear.bias is not None, + quant_config=quant_config, + prefix=prefix, + return_bias=False, + **vllm_linear_kwargs, + ) + + +def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: + """Replace a Transformers RMSNorm with vLLM's RMSNorm. + + This method assumes: + - Weight is stored as `weight`. + - Epsilon is stored as `eps` or `variance_epsilon`. + - `with_scale` indicates whether the layer has a weight (Gemma3n only). + - `var_hidden_size` is only ever used for Intern vision encoder in vLLM + and Transformers doesn't appear to have the same concept. + """ + eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6) + kwargs = {"hidden_size": hidden_size, "eps": eps} + # Update hidden size if weight is available + weight_meta = getattr(rms_norm, "weight", None) + if weight_meta is not None: + kwargs["hidden_size"] = weight_meta.size(0) + # Check if weight is all zeros, which indicates GemmaRMSNorm + # We must create a new instance because rms_norm is on meta + try: + with torch.device("cpu"): + weight_test = getattr(rms_norm.__class__(1), "weight", None) + except Exception: + logger.warning( + "Failed to determine if RMSNorm weight is centered on zero or one. " + "Defaulting to one." 
+ ) + weight_test = None + if weight_test is not None and torch.all(weight_test == 0): + return GemmaRMSNorm(**kwargs) + # Otherwise assume it's a regular RMSNorm + kwargs["has_weight"] = getattr(rms_norm, "with_scale", True) + if weight_meta is not None: + kwargs["dtype"] = weight_meta.dtype + else: + # No weight, fall back to weightless RMSNorm + kwargs["has_weight"] = False + return RMSNorm(**kwargs) + + +def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): + logger.debug("%s: %s -> %s", name, old_module, new_module) + + +def get_feature_request_tip( + model: str, + trust_remote_code: bool, +) -> str: + hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new" + gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose" + url = hf_url if trust_remote_code else gh_url + prefix = f"Please open {url} to request support for this feature. " + if Path(model).exists(): + prefix = "" + doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models" + tip = f"See {doc_url} for instructions on how to add support yourself." + return f"{prefix}{tip}" + + +def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool: + """ + Callable to be passed to `@support_torch_compile`'s `enable_if` argument. + + Defaults to `True` but is disabled in the following situations: + + - The model uses dynamic rope scaling. + """ + text_config = vllm_config.model_config.hf_config.get_text_config() + # Dynamic rope scaling is not compatible with torch.compile + rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {} + return rope_scaling.get("rope_type") != "dynamic" diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py deleted file mode 100644 index 98d2611351c0..000000000000 --- a/vllm/model_executor/models/transformers_pooling.py +++ /dev/null @@ -1,223 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright 2024 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Wrapper around `transformers` models for pooling tasks.""" - -from typing import Optional, Union - -import torch -from transformers import AutoModelForSequenceClassification - -from vllm.attention import Attention, AttentionType -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import ( - ClassifierPooler, - CLSPool, - DispatchPooler, - Pooler, -) -from vllm.sequence import IntermediateTensors - -from .interfaces_base import VllmModelForPooling -from .transformers import TransformersBase, can_enable_torch_compile -from .transformers_moe import TransformersMoEBase -from .utils import WeightsMapper - - -class TransformersPoolingBase(TransformersBase, VllmModelForPooling): - hf_to_vllm_mapper = WeightsMapper( - # These are applied in order, so the order matters! 
- orig_to_new_prefix={ - # Handle BERT-like models - "roberta": "model", - "bert": "model", - # Add `model.` prefix for base model checkpoints - "": "model.", - # Remove `model.` prefix if it was already there - "model.model.": "model.", - # Classifier/scoring heads will be adjacent to `model` - "model.score": "classifier", - "model.classifier": "classifier", - }, - orig_to_new_suffix={ - # Replace legacy suffixes used for norms - ".gamma": ".weight", - ".beta": ".bias", - }, - ) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - # Skip unsupported/unwanted output embeddings layers - self.skip_prefixes.extend( - [ - "model.lm_head.", - "model.predictions.", - "model.qa_outputs.", - "model.embeddings_project.", - "model.discriminator_predictions.", - ] - ) - - # Some encoder models have the position_ids buffer in the checkpoint. - # vLLM will always pass position_ids as an argument, so we skip loading - # the buffer if it exists - self.skip_substrs.append("position_ids") - - # Some encoder models have the bias of the final classifier layer - # in the checkpoint. vLLM does not use this bias, so we skip loading - # it if it exists - self.skip_substrs.append("score.bias") - - # roberta-like models an extra padding in positions. - # FIXME(Isotr0py): This is quite hacky for roberta edge case, - # we should find a better way to handle this. - self.is_roberta = "roberta" in self.text_config.model_type - self.padding_idx = self.text_config.pad_token_id - - def create_attention_instances( - self, attn_type: AttentionType = AttentionType.DECODER - ) -> dict[int, Attention]: - # TODO(hmellor): Better way to detect encoder models - # In encoder models, the attention layers will have `is_causal=False` - is_encoder = lambda m: not getattr(m, "is_causal", True) - # vLLM does not support encoder-decoder models, so if any encoder layer - # is found, we assume the whole model is an encoder model - if any(is_encoder(m) for m in self.model.modules()): - attn_type = AttentionType.ENCODER_ONLY - - # Check minimum transformers version for encoder models support - if attn_type == AttentionType.ENCODER_ONLY: - self.check_version("4.57.0.dev0", "encoder models support") - - return super().create_attention_instances(attn_type) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if self.is_roberta: - # RoBERTa-specific positions padding - positions += self.padding_idx + 1 - return super().forward( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersEmbeddingModel(TransformersPoolingBase): - default_pooling_type = "CLS" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - self.pooler = DispatchPooler( - { - "encode": Pooler.for_encode(pooler_config), - "embed": Pooler.for_embed(pooler_config), - } - ) - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersForSequenceClassification(TransformersPoolingBase): - default_pooling_type = "CLS" - - def __init__(self, *, vllm_config: VllmConfig, 
prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - - # Certain information about the the model and classifier can only be - # inferred from the `ForSequenceClassification` class. Therefore, we - # instantiate it on the "meta" device to avoid allocating GPU memory. - with torch.device("meta"): - seq_cls_model = AutoModelForSequenceClassification.from_config( - self.config, - torch_dtype=self.model_config.dtype, - trust_remote_code=self.model_config.trust_remote_code, - ) - - # When used for sequence classification, some models have their - # pooling layers removed. Make sure this is reflected in vLLM. - for module in seq_cls_model.modules(): - if hasattr(module, "pooler") and module.pooler is None: - self.model.pooler = None - break - if self.model.pooler is not None: - raise ValueError( - "Sequence classification models with pooling layers are not " - "supported yet in the Transformers backend." - ) - - # Unlike `lm_head`, `classifier` is not always `nn.Linear`. - self.classifier = seq_cls_model.classifier - self.init_parameters(self.classifier, dtype=self.model_config.head_dtype) - - class ClassifierWithReshape(self.classifier.__class__): - """CLSPool has already been applied in `pooling`. - Add dim to match expected input shape of `classifier.forward`.""" - - def forward(self, *args, **kwargs): - if len(args) > 0: - args = (args[0].unsqueeze(1), *args[1:]) - return super().forward(*args, **kwargs) - - self.classifier.__class__ = ClassifierWithReshape - - self.pooler = DispatchPooler( - { - "encode": Pooler.for_encode(pooler_config), - "classify": ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config - ), - ), - "score": ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config - ), - ), - } - ) - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEEmbeddingModel(TransformersMoEBase, TransformersEmbeddingModel): - pass - - -@support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEForSequenceClassification( - TransformersMoEBase, TransformersForSequenceClassification -): - pass diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 1fc34f48401d..95d574fb81d7 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -5,7 +5,7 @@ """PyTorch Ultravox model.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, TypeAlias import torch from torch import nn @@ -68,7 +68,7 @@ class UltravoxAudioFeatureInputs(TensorSchema): type: Literal["audio_features"] data: Annotated[ - Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]], + torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]], TensorShape("bn", "nmb", "t"), ] lens: Annotated[torch.Tensor, TensorShape("bn")] @@ -92,11 +92,13 @@ class UltravoxAudioEmbeddingInputs(TensorSchema): type: Literal["audio_embeds"] data: Annotated[ - Union[torch.Tensor, list[torch.Tensor]], TensorShape("b", "na", "afs", "hs") + torch.Tensor | list[torch.Tensor], TensorShape("b", "na", "afs", "hs") ] -UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, UltravoxAudioEmbeddingInputs] +UltravoxAudioInputs: 
TypeAlias = ( + UltravoxAudioFeatureInputs | UltravoxAudioEmbeddingInputs +) class UltravoxProcessingInfo(BaseProcessingInfo): @@ -119,7 +121,7 @@ def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": None} @@ -133,7 +135,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: feature_extractor = self.info.get_feature_extractor() @@ -346,7 +348,7 @@ def max_context_length(self): ) def get_attention_mask_by_audio_len( - self, audio_lens: Optional[torch.Tensor], hidden_states: torch.Tensor + self, audio_lens: torch.Tensor | None, hidden_states: torch.Tensor ): """ Create attention mask based on audio lengths to mask out padding tokens @@ -376,7 +378,7 @@ def get_attention_mask_by_audio_len( def forward( self, input_features: torch.Tensor, - audio_lens: Optional[torch.Tensor] = None, + audio_lens: torch.Tensor | None = None, ): expected_seq_length = self.max_context_length if input_features.shape[-1] > expected_seq_length: @@ -431,7 +433,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("audio"): return "<|audio|>" @@ -514,7 +516,7 @@ def _audio_features_to_embeddings( def _parse_and_validate_audio_input( self, **kwargs: object - ) -> Optional[UltravoxAudioInputs]: + ) -> UltravoxAudioInputs | None: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) audio_lens = kwargs.pop("audio_lens", None) @@ -541,7 +543,7 @@ def _parse_and_validate_audio_input( def _process_audio_input( self, audio_input: UltravoxAudioInputs, - ) -> Union[NestedTensors, tuple[torch.Tensor, ...]]: + ) -> NestedTensors | tuple[torch.Tensor, ...]: if audio_input["type"] == "audio_embeds": return audio_input["data"] @@ -587,9 +589,9 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, # Multi-modal token ID may exceed vocab size handle_oov_mm_token: bool = True, ) -> torch.Tensor: @@ -608,10 +610,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: """Run forward pass for Ultravox One key thing to understand is the `input_ids` already accounts for the @@ -651,7 +653,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def pad_and_concat_to_dim3( - features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]], + features: torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]], 
) -> torch.Tensor: """ Pad and concatenate a list of tensors. diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 2a64f6865f12..e86fc23c7d36 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -4,7 +4,7 @@ import itertools from collections.abc import Iterable, Mapping from dataclasses import dataclass, field -from typing import Any, Literal, Optional, Protocol, Union, overload +from typing import Any, Literal, Protocol, overload import torch import torch.nn as nn @@ -22,17 +22,19 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import NestedTensors from vllm.sequence import IntermediateTensors -from vllm.utils import ( - cdiv, - direct_register_custom_op, - get_cuda_view_from_cpu_tensor, +from vllm.utils import cdiv +from vllm.utils.platform_utils import ( is_pin_memory_available, is_uva_available, ) +from vllm.utils.torch_utils import ( + direct_register_custom_op, + get_cuda_view_from_cpu_tensor, +) logger = init_logger(__name__) -WeightsMapping = Mapping[str, Optional[str]] +WeightsMapping = Mapping[str, str | None] """If a key maps to a value of `None`, the corresponding weight is ignored.""" @@ -44,7 +46,15 @@ class WeightsMapper: orig_to_new_prefix: WeightsMapping = field(default_factory=dict) orig_to_new_suffix: WeightsMapping = field(default_factory=dict) - def _map_name(self, key: str) -> Optional[str]: + def __or__(self, other: "WeightsMapper") -> "WeightsMapper": + """Combine two `WeightsMapper`s by merging their mappings.""" + return WeightsMapper( + orig_to_new_substr={**self.orig_to_new_substr, **other.orig_to_new_substr}, + orig_to_new_prefix={**self.orig_to_new_prefix, **other.orig_to_new_prefix}, + orig_to_new_suffix={**self.orig_to_new_suffix, **other.orig_to_new_suffix}, + ) + + def _map_name(self, key: str) -> str | None: for substr, new_key in self.orig_to_new_substr.items(): if substr in key: if new_key is None: @@ -99,13 +109,13 @@ class AutoWeightsLoader: the weights only once. The weight loading logic for individual modules can be overridden - by defining a ``load_weights`` method. + by defining a `load_weights` method. Similarly, the weight loading logic for individual parameters can be - overridden by defining a ``weight_loader`` method. + overridden by defining a `weight_loader` method. Detailed weight loading information can be viewed by setting the - environment variable ``VLLM_LOGGING_LEVEL=DEBUG``. + environment variable `VLLM_LOGGING_LEVEL=DEBUG`. 
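A usage sketch for the new `WeightsMapper.__or__` above (import path as used elsewhere in this diff). Because the merge is plain dict unpacking, the right-hand mapper wins when both define the same key:

```python
from vllm.model_executor.models.utils import WeightsMapper

legacy = WeightsMapper(orig_to_new_suffix={".gamma": ".weight", ".beta": ".bias"})
multimodal = WeightsMapper(orig_to_new_prefix={"vision_tower": "model.vision_tower"})

combined = legacy | multimodal
assert combined.orig_to_new_suffix[".gamma"] == ".weight"
assert combined.orig_to_new_prefix["vision_tower"] == "model.vision_tower"
```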
""" # Models trained using early version ColossalAI @@ -120,10 +130,10 @@ def __init__( self, module: nn.Module, *, - skip_prefixes: Optional[list[str]] = None, - skip_substrs: Optional[list[str]] = None, - ignore_unexpected_prefixes: Optional[list[str]] = None, - ignore_unexpected_suffixes: Optional[list[str]] = None, + skip_prefixes: list[str] | None = None, + skip_substrs: list[str] | None = None, + ignore_unexpected_prefixes: list[str] | None = None, + ignore_unexpected_suffixes: list[str] | None = None, ) -> None: super().__init__() @@ -306,7 +316,7 @@ def load_weights( self, weights: Iterable[tuple[str, torch.Tensor]], *, - mapper: Optional[WeightsMapper] = None, + mapper: WeightsMapper | None = None, ) -> set[str]: if mapper is not None: weights = mapper.apply(weights) @@ -323,8 +333,8 @@ def init_vllm_registered_model( vllm_config: VllmConfig, *, prefix: str = "", - hf_config: Optional[PretrainedConfig] = None, - architectures: Optional[list[str]] = None, + hf_config: PretrainedConfig | None = None, + architectures: list[str] | None = None, ) -> nn.Module: """ Helper function to initialize an inner model registered to vLLM, @@ -352,7 +362,7 @@ def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ... @overload def flatten_bn( - x: Union[list[torch.Tensor], torch.Tensor], + x: list[torch.Tensor] | torch.Tensor, *, concat: Literal[True], ) -> torch.Tensor: ... @@ -360,21 +370,21 @@ def flatten_bn( @overload def flatten_bn( - x: Union[list[torch.Tensor], torch.Tensor], + x: list[torch.Tensor] | torch.Tensor, *, concat: bool = False, -) -> Union[list[torch.Tensor], torch.Tensor]: ... +) -> list[torch.Tensor] | torch.Tensor: ... def flatten_bn( - x: Union[list[torch.Tensor], torch.Tensor], + x: list[torch.Tensor] | torch.Tensor, *, concat: bool = False, -) -> Union[list[torch.Tensor], torch.Tensor]: +) -> list[torch.Tensor] | torch.Tensor: """ - Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs. + Flatten the `B` and `N` dimensions of batched multimodal inputs. - The input tensor should have shape ``(B, N, ...)```. + The input tensor should have shape `(B, N, ...)`. """ if isinstance(x, torch.Tensor): return x.flatten(0, 1) @@ -410,18 +420,26 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: return " + ".join(_embedding_count_expression(inner) for inner in embeddings) +def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + + def _merge_multimodal_embeddings( inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, is_multimodal: torch.Tensor, ) -> torch.Tensor: """ - Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the - positions in ``inputs_embeds`` corresponding to placeholder tokens in - ``input_ids``. + Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the + positions in `inputs_embeds` corresponding to placeholder tokens in + `input_ids`. Note: - This updates ``inputs_embeds`` in place. + This updates `inputs_embeds` in place. 
""" if len(multimodal_embeddings) == 0: return inputs_embeds @@ -464,17 +482,17 @@ def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, - placeholder_token_id: Union[int, list[int]], + placeholder_token_id: int | list[int], ) -> torch.Tensor: """ - Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the - positions in ``inputs_embeds`` corresponding to placeholder tokens in - ``input_ids``. + Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the + positions in `inputs_embeds` corresponding to placeholder tokens in + `input_ids`. - ``placeholder_token_id`` can be a list of token ids (e.g, token ids + `placeholder_token_id` can be a list of token ids (e.g, token ids of img_start, img_break, and img_end tokens) when needed: This means - the order of these tokens in the ``input_ids`` MUST MATCH the order of - their embeddings in ``multimodal_embeddings`` since we need to + the order of these tokens in the `input_ids` MUST MATCH the order of + their embeddings in `multimodal_embeddings` since we need to slice-merge instead of individually scattering. For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where @@ -489,7 +507,7 @@ def merge_multimodal_embeddings( input_ids for a correct embedding merge. Note: - This updates ``inputs_embeds`` in place. + This updates `inputs_embeds` in place. """ if isinstance(placeholder_token_id, list): is_multimodal = isin_list(input_ids, placeholder_token_id) @@ -769,13 +787,6 @@ def fast_topk( return torch.topk(values, topk, dim=dim) -def get_model_hidden_size(hf_config: PretrainedConfig) -> int: - if hasattr(hf_config, "hidden_size"): - return hf_config.hidden_size - text_config = hf_config.get_text_config() - return text_config.hidden_size - - # Chunk x along the num_tokens axis for sequence parallelism # NOTE: This is wrapped in a torch custom op to work around the following issue: # The output tensor can have a sequence length 0 at small input sequence lengths diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 74262f8b94a6..b5f6c60514c0 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -4,7 +4,8 @@ import itertools import math from abc import ABC, abstractmethod -from typing import Callable, Final, Generic, Literal, Optional, Protocol, TypeVar, Union +from collections.abc import Callable +from typing import Final, Generic, Literal, Protocol, TypeAlias, TypeVar import torch from transformers import PretrainedConfig @@ -77,14 +78,22 @@ def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInf raise NotImplementedError(msg) -def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: +def get_vit_attn_backend( + head_size: int, + dtype: torch.dtype, + *, + attn_backend_override: _Backend | None = None, +) -> _Backend: """ Get the available attention backend for Vision Transformer. 
""" + if attn_backend_override is not None: + return attn_backend_override + # Lazy import to avoid circular dependency from vllm.attention.selector import get_env_variable_attn_backend - selected_backend: Optional[_Backend] = get_env_variable_attn_backend() + selected_backend: _Backend | None = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend @@ -93,14 +102,13 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: VisionFeatureSelectStrategyStr = Literal["class", "default", "full"] -VisionFeatureSelectStrategy = Union[ - VisionFeatureSelectStrategyStr, - Callable[[torch.Tensor], torch.Tensor], -] +VisionFeatureSelectStrategy: TypeAlias = ( + VisionFeatureSelectStrategyStr | Callable[[torch.Tensor], torch.Tensor] +) def _get_vision_feature_selector( - strategy: Union[VisionFeatureSelectStrategy, str], + strategy: VisionFeatureSelectStrategy | str, ) -> Callable[[torch.Tensor], torch.Tensor]: if callable(strategy): return strategy @@ -121,7 +129,7 @@ def _get_vision_feature_selector( def get_num_selected_vision_tokens( num_vision_tokens: int, - strategy: Union[VisionFeatureSelectStrategy, str], + strategy: VisionFeatureSelectStrategy | str, ) -> int: if callable(strategy): dummy_features = torch.empty(1, num_vision_tokens, 64) # [B, L, D] @@ -141,12 +149,12 @@ def get_num_selected_vision_tokens( def resolve_visual_encoder_outputs( - encoder_outputs: Union[torch.Tensor, list[torch.Tensor]], - post_layer_norm: Optional[torch.nn.LayerNorm], + encoder_outputs: torch.Tensor | list[torch.Tensor], + post_layer_norm: torch.nn.LayerNorm | None, *, - select_layers: Optional[list[int]] = None, - max_possible_layers: Optional[int] = None, - feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + select_layers: list[int] | None = None, + max_possible_layers: int | None = None, + feature_select_strategy: VisionFeatureSelectStrategy | None = None, ) -> torch.Tensor: """Given the outputs a visual encoder module that may correspond to the output of the last layer, or a list of hidden states to be stacked, @@ -499,3 +507,56 @@ def run_dp_sharded_mrope_vision_model( "Found unassigned embeddings" ) return out_embeddings + + +def get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, +) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(len(t_index), -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(len(t_index), llm_grid_h, -1) + .flatten() + ) + t_index_tensor = ( + torch.Tensor(t_index) + .to(llm_grid_h.device) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .long() + .flatten() + ) + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + +# Due to a performance regression with Conv3D in PyTorch2.9, we reshape +# Conv3D weights to Linear weights for better performance. +# See: https://github.com/vllm-project/vllm/issues/27406 +# and https://github.com/pytorch/pytorch/issues/166122 +# FIXME(Isotr0py): Revert the PR introduces this workaround +# (https://github.com/vllm-project/vllm/pull/27418), +# once the performance issue is resolved in PyTorch. 
+def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor: + """ + Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride. + """ + out_channels, in_channels, kt, kh, kw = conv3d_weight.shape + linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw) + return linear_weight diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index f929ba9913ec..cce18984b67e 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -5,19 +5,15 @@ from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from math import ceil -from typing import Literal, Optional, Union, cast +from typing import Literal, cast import numpy as np import regex as re import torch import torch.nn as nn from mistral_common.audio import mel_filter_bank -from mistral_common.protocol.instruct.messages import ( - AudioChunk, - RawAudio, - TextChunk, - UserMessage, -) +from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk +from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder @@ -129,9 +125,9 @@ def get_num_audio_tokens( def __call__( self, - text: Optional[Union[TextInput, list[TextInput]]] = None, - audios: Optional[Union[np.ndarray, list[np.ndarray]]] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + text: TextInput | list[TextInput] | None = None, + audios: np.ndarray | list[np.ndarray] | None = None, + return_tensors: str | TensorType | None = None, **kwargs, ) -> Mapping[str, NestedTensors]: if text is None: @@ -192,7 +188,7 @@ def get_tokenizer(self) -> MistralTokenizer: def get_hf_processor(self) -> VoxtralProcessorAdapter: return VoxtralProcessorAdapter(self.get_tokenizer()) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": 5} # Performance tends to degrade after 5 def get_mm_max_tokens_per_item( @@ -220,7 +216,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) @@ -238,7 +234,7 @@ def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> ProcessorInputs: tokenizer = self.info.get_tokenizer() @@ -307,11 +303,11 @@ def get_replacement(item_idx: int): def _cached_apply_hf_processor( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, @@ -390,10 +386,10 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + intermediate_tensors: 
IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> torch.Tensor | IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -405,7 +401,7 @@ def forward( def get_multimodal_embeddings( self, **kwargs - ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...], None]: + ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None: audio_inputs = self._parse_and_validate_audio_arrays(**kwargs) if audio_inputs is None: return None @@ -437,7 +433,7 @@ def get_multimodal_embeddings( def _parse_and_validate_audio_arrays( self, **kwargs: object - ) -> Union[list[torch.Tensor], None]: + ) -> list[torch.Tensor] | None: audio_arrays = kwargs.pop("audio_arrays", None) if audio_arrays is None: return None @@ -454,7 +450,7 @@ def _parse_and_validate_audio_arrays( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: return self.language_model.compute_logits(hidden_states) @classmethod @@ -479,10 +475,10 @@ def get_generation_prompt( audio: np.ndarray, model_config: ModelConfig, stt_config: SpeechToTextConfig, - language: Optional[str], + language: str | None, task_type: Literal["transcribe", "translate"], request_prompt: str, - to_language: Optional[str], + to_language: str | None, ) -> PromptType: tokenizer = cached_tokenizer_from_config(model_config) audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless @@ -504,7 +500,7 @@ def get_num_audio_tokens( audio_duration_s: float, stt_config: SpeechToTextConfig, model_config: ModelConfig, - ) -> Optional[int]: + ) -> int | None: """ Map from audio duration to number of audio tokens produced by the ASR model, without running a forward pass. 
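Most hunks in this patch, like the ones above, mechanically convert `Optional[...]`/`Union[...]` annotations to the PEP 604 `X | Y` spelling and mark module-level aliases with `TypeAlias`. A small sanity check of the equivalence (requires Python 3.10+, which is what makes the new spelling usable at runtime):

```python
from typing import Optional, TypeAlias, Union

# The two spellings construct equal typing objects; the patch standardizes
# on the `X | Y` form.
assert Optional[int] == (int | None)
assert Union[int, str] == (int | str)

PromptLike: TypeAlias = str | list[int]          # new style
PromptLikeOld: TypeAlias = Union[str, list[int]]  # old style
assert PromptLike == PromptLikeOld
```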
@@ -797,7 +793,7 @@ def prepare_inputs_for_conv( return torch.stack(chunked_features), chunks_per_example def forward( - self, input_features: Union[torch.Tensor, list[torch.Tensor]] + self, input_features: torch.Tensor | list[torch.Tensor] ) -> list[torch.Tensor]: if not isinstance(input_features, list): input_features = [input_features] diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 397556cbbcc4..ccfe1871ef07 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -4,7 +4,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext -from typing import Annotated, Literal, Optional, Union, cast +from typing import Annotated, Literal, cast import numpy as np import torch @@ -34,7 +34,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -53,6 +52,7 @@ from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription from .utils import ( @@ -137,7 +137,7 @@ class WhisperAudioInputs(TensorSchema): """ input_features: Annotated[ - Optional[list[torch.Tensor]], + list[torch.Tensor] | None, TensorShape("b", "nmb", "t"), ] @@ -185,8 +185,8 @@ def __init__( num_heads: int, bias: bool = True, attn_type: AttentionType = AttentionType.DECODER, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -258,7 +258,7 @@ def _init_qkv( self, embed_dim: int, bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: self.qkv_proj = QKVParallelLinear( @@ -291,8 +291,8 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__( @@ -309,7 +309,7 @@ def _init_qkv( self, embed_dim: int, bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: self.q_proj = ColumnParallelLinear( @@ -332,7 +332,7 @@ def _init_qkv( def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], + encoder_hidden_states: torch.Tensor | None, ): q, _ = self.q_proj(hidden_states) @@ -357,7 +357,7 @@ def __init__( embed_dim: int, ffn_dim: int, act_fn: str, - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() @@ -463,7 +463,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, 
hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], + encoder_hidden_states: torch.Tensor | None, ): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -521,7 +521,7 @@ def __init__( sinusoids(*self.embed_positions.weight.shape) ) - def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]): + def forward(self, input_features: torch.Tensor | list[torch.Tensor]): hidden_states = [] for features in input_features: embeds = nn.functional.gelu(self.conv1(features)) @@ -569,7 +569,7 @@ def forward( self, input_ids, positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor], + encoder_hidden_states: torch.Tensor | None, ): inputs_embeds = self.get_input_embeddings(input_ids) positions = self.embed_positions(positions) @@ -600,8 +600,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def forward( self, - input_features: Optional[Union[torch.Tensor, list[torch.Tensor]]], - input_ids: Optional[torch.Tensor], + input_features: torch.Tensor | list[torch.Tensor] | None, + input_ids: torch.Tensor | None, positions: torch.Tensor, ) -> torch.Tensor: encoder_outputs = self.get_encoder_outputs(input_features) @@ -614,8 +614,8 @@ def forward( def get_encoder_outputs( self, - input_features: Optional[Union[torch.Tensor, list[torch.Tensor]]], - ) -> Optional[torch.Tensor]: + input_features: torch.Tensor | list[torch.Tensor] | None, + ) -> torch.Tensor | None: if input_features is None: return None return self.encoder(input_features) @@ -670,7 +670,7 @@ def get_hf_processor(self, **kwargs: object) -> WhisperProcessor: processor_class.tokenizer_class = tokenizer_class return self.ctx.get_hf_processor(processor_class, **kwargs) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": 1} def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: @@ -693,7 +693,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: feature_extractor = self.info.get_feature_extractor() @@ -721,9 +721,9 @@ def pad_dummy_encoder_prompt(self) -> bool: def create_encoder_prompt( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: + ) -> str | list[int]: # Strictly speaking, whisper encoder only accept audio features. # We create a dummy encoder prompt here which will be padded to # num_audio_tokens. So that we can create dummy data from this @@ -804,7 +804,7 @@ class WhisperForConditionalGeneration( supported_languages = ISO639_1_SUPPORTED_LANGS @classmethod - def validate_language(cls, language: Optional[str]) -> Optional[str]: + def validate_language(cls, language: str | None) -> str | None: if language is None: # TODO language should be optional and can be guessed. # For now we default to en. 
See @@ -823,10 +823,10 @@ def get_generation_prompt( audio: np.ndarray, model_config: ModelConfig, # not needed here stt_config: SpeechToTextConfig, - language: Optional[str], + language: str | None, task_type: Literal["transcribe", "translate"], request_prompt: str, - to_language: Optional[str], + to_language: str | None, ) -> PromptType: if language is None: raise ValueError( @@ -849,7 +849,7 @@ def get_generation_prompt( return cast(PromptType, prompt) @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("audio"): return None @@ -872,7 +872,7 @@ def get_num_audio_tokens( audio_duration_s: float, stt_config: SpeechToTextConfig, model_config: ModelConfig, - ) -> Optional[int]: + ) -> int | None: processor = cached_get_processor(model_config.model) hop_length = processor.feature_extractor.hop_length assert hop_length is not None @@ -928,9 +928,9 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: def get_input_embeddings( self, input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + multimodal_embeddings: MultiModalEmbeddings | None = None, *, - is_multimodal: Optional[torch.Tensor] = None, + is_multimodal: torch.Tensor | None = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: # This method just returns the decoder sequence embeddings since diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index b69204d02096..2610aa253b57 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -10,7 +10,7 @@ from collections.abc import Iterable from itertools import cycle -from typing import Any, Optional, Union +from typing import Any import torch from torch import nn @@ -60,8 +60,8 @@ def __init__( self, input_dim: int, rank: int, - output_dim: Union[int, list[int]], - quant_config: Optional[QuantizationConfig] = None, + output_dim: int | list[int], + quant_config: QuantizationConfig | None = None, prefix: str = "", ): """Initialize the attention layer. @@ -106,8 +106,8 @@ def __init__( config: Zamba2Config, bare_block_idx: int, num_hybrid_layers: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: """Initialize the attention layer. @@ -288,7 +288,7 @@ def __init__( config: Zamba2Config, bare_block_idx: int, num_hybrid_layers: dict[int, int], - quant_config: Optional[QuantizationConfig] = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: """Initialize the MLP layer. @@ -386,8 +386,8 @@ def __init__( config: Zamba2Config, bare_block_idx: int, num_hybrid_layers: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: """Initialize the decoder layer. 
@@ -484,9 +484,9 @@ class Zamba2MambaDecoderLayer(nn.Module): def __init__( self, config: Zamba2Config, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: """Initialize the Mamba decoder layer. @@ -523,9 +523,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - transformer_hidden_states: Optional[torch.Tensor] = None, - positions: Optional[torch.Tensor] = None, - original_hidden_states: Optional[torch.Tensor] = None, + transformer_hidden_states: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + original_hidden_states: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass through the Mamba decoder layer. @@ -581,9 +581,9 @@ def __init__( shared_transformer: Zamba2AttentionDecoderLayer, config: Zamba2Config, block_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: """Initialize the hybrid layer. @@ -764,8 +764,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: """Forward pass through the model. Args: @@ -947,7 +947,7 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, + inputs_embeds: torch.Tensor | None = None, **kwargs: Any, ) -> torch.Tensor: """Forward pass through the model. @@ -973,7 +973,7 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """Compute logits for next token prediction. Args: diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 9341665f1bca..d3a91feab64d 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Hashable +from collections.abc import Callable, Hashable from fractions import Fraction -from typing import Callable, Optional, Union from weakref import WeakValueDictionary import torch @@ -36,7 +35,7 @@ class BasevLLMParameter(Parameter): into the parameter when the provided weight loader is called. """ - def __new__(cls, data: Optional[torch.Tensor], **kwargs): + def __new__(cls, data: torch.Tensor | None, **kwargs): return super().__new__(cls, data=data, requires_grad=False) def __init__(self, data: torch.Tensor, weight_loader: Callable): @@ -71,7 +70,7 @@ def weight_loader(self) -> Callable: # NOTE(@ksayers) some models such as mamba_mixer2 override the # weight loader to support custom loading. In the future, model-specific # weight loading should be implemented via Model.load_weights. 
In the - # meantime, support deleting and overriding `weight_loader`` attribute + # meantime, support deleting and overriding `weight_loader` attribute if self._weight_loader is None: raise AttributeError( f"{self.__class__.__name__} weight_loader attribute has been deleted" @@ -109,7 +108,7 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): self._assert_and_load(loaded_weight) - def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + def _shard_id_as_int(self, shard_id: str | int) -> int: if isinstance(shard_id, int): return shard_id @@ -290,7 +289,7 @@ def load_column_parallel_weight(self, *args, **kwargs): super().load_row_parallel_weight(*args, **kwargs) def _load_into_shard_id( - self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs + self, loaded_weight: torch.Tensor, shard_id: str | int, **kwargs ): """ Slice the parameter data based on the shard id for @@ -320,10 +319,10 @@ class PackedColumnParameter(_ColumnvLLMParameter): def __init__( self, - packed_factor: Union[int, Fraction], + packed_factor: int | Fraction, packed_dim: int, - marlin_tile_size: Optional[int] = None, - bitblas_tile_size: Optional[int] = None, + marlin_tile_size: int | None = None, + bitblas_tile_size: int | None = None, **kwargs, ): self._packed_factor = packed_factor @@ -371,10 +370,10 @@ class PackedvLLMParameter(ModelWeightParameter): def __init__( self, - packed_factor: Union[int, Fraction], + packed_factor: int | Fraction, packed_dim: int, - marlin_tile_size: Optional[int] = None, - bitblas_tile_size: Optional[int] = None, + marlin_tile_size: int | None = None, + bitblas_tile_size: int | None = None, **kwargs, ): self._packed_factor = packed_factor @@ -437,7 +436,7 @@ class SharedWeightParameter(BasevLLMParameter): local_tensors: set[torch.Tensor] # dictionary mapping partition indices to associated parameters - partitions: dict[int, Union[ModelWeightParameter, Parameter]] + partitions: dict[int, ModelWeightParameter | Parameter] def __new__(cls, **kwargs): return super().__new__(cls, data=None, **kwargs) @@ -547,7 +546,7 @@ def _fake_weight_loader( self, param: BasevLLMParameter, loaded_weight: torch.Tensor, - loaded_weight_shard_id: Optional[Union[str, int]], + loaded_weight_shard_id: str | int | None, ): raise ValueError( "When loading partition weights of " diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 4abd2625f806..759b809433b1 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -3,10 +3,12 @@ """Utils for model executor.""" import copy -from typing import Any, Optional +from typing import Any import torch +from vllm.utils.torch_utils import is_torch_equal_or_newer + def set_random_seed(seed: int) -> None: from vllm.platforms import current_platform @@ -16,7 +18,7 @@ def set_random_seed(seed: int) -> None: def set_weight_attrs( weight: torch.Tensor, - weight_attrs: Optional[dict[str, Any]], + weight_attrs: dict[str, Any] | None, ): """Set attributes on a weight tensor. 
@@ -83,3 +85,10 @@ def get_moe_expert_mapping( if child_map is not None: return child_map() return [] + + +def maybe_disable_graph_partition(current_backend: str) -> dict[str, bool]: + if current_backend == "inductor" and is_torch_equal_or_newer("2.9.0.dev"): + return {"graph_partition": False} + else: + return {} diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 1747caf26cef..78cbcd8e5427 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -12,10 +12,7 @@ import vllm.envs as envs from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - compute_aligned_M, - deep_gemm_block_shape, -) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( @@ -23,7 +20,60 @@ ) from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod -from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous +from vllm.utils.deep_gemm import ( + fp8_gemm_nt, + get_mk_alignment_for_contiguous_layout, + m_grouped_fp8_gemm_nt_contiguous, +) + + +def _generate_optimal_warmup_m_values( + max_tokens: int, n: int, device: torch.device +) -> list[int]: + """ + Generate M values that cover all possible DeepGEMM kernel configurations. + Reference: https://github.com/deepseek-ai/DeepGEMM/blob/79f48ee15a82dd5fad5cd9beaa393c1f755e6b55/csrc/jit_kernels/heuristics/common.hpp + + Args: + max_tokens: Maximum number of tokens to warmup for + n: The actual N dimension from the weight tensor + device: The torch device to get properties from. + """ + + def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + # DeepGEMM's possible block sizes + block_ms = [64, 128, 256] + block_ns = list(range(16, min(257, n + 1), 16)) + num_sms = torch.cuda.get_device_properties(device).multi_processor_count + + m_values = set() + + # Always include small cases + m_values.update([1, 2, 4] + [i for i in range(8, 65, 8)]) + + # Collect M values where different wave patterns occur + for block_m in block_ms: + for block_n in block_ns: + if block_n > n: + continue + + # Add key M boundaries for this block combination + for wave in range(1, 11): # Up to 10 waves + # M where this block config transitions to next wave + target_blocks = wave * num_sms + m = target_blocks * block_m // ceil_div(n, block_n) + if 1 <= m <= max_tokens: + m_values.add(m) + + # Add block_m boundaries + for multiple in range(1, max_tokens // block_m + 1): + m = multiple * block_m + if m <= max_tokens: + m_values.add(m) + + return sorted(m_values) def _extract_data_from_linear_base_module( @@ -80,7 +130,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: """ Return True if the input module/layer could be processed with DeepGEMM. 
""" - block_size = deep_gemm_block_shape()[0] + block_size = get_mk_alignment_for_contiguous_layout()[0] if not ( isinstance(module, LinearBase) and isinstance(module.quant_method, Fp8LinearMethod) @@ -90,7 +140,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: w, _, block_sizes = _extract_data_from_linear_base_module(module) return ( - block_sizes == deep_gemm_block_shape() + block_sizes == get_mk_alignment_for_contiguous_layout() and w.ndim == 2 and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0 @@ -106,7 +156,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: if ( moe_quant_config is None or moe_quant_config.quant_dtype != torch.float8_e4m3fn - or moe_quant_config.block_shape != deep_gemm_block_shape() + or moe_quant_config.block_shape != get_mk_alignment_for_contiguous_layout() ): return False @@ -127,7 +177,7 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: return n, k = w.size() - block_m = deep_gemm_block_shape()[0] + block_m = get_mk_alignment_for_contiguous_layout()[0] device = w.device a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn) @@ -136,14 +186,27 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: ) out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16) - pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") - num_tokens = max_tokens - while num_tokens > 0: + # Use optimal M values only if VLLM_DEEP_GEMM_WARMUP is set to "relax". + # Otherwise warmup all token sizes to avoid JIT compilation in hotpath + if envs.VLLM_DEEP_GEMM_WARMUP == "relax": + m_values = _generate_optimal_warmup_m_values(max_tokens, n, device) + desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [relaxed]" + else: + assert envs.VLLM_DEEP_GEMM_WARMUP == "full", ( + "Expected " + 'VLLM_DEEP_GEMM_WARMUP env to be set to "full" but got ' + f"{envs.VLLM_DEEP_GEMM_WARMUP}" + ) + m_values = list(range(1, max_tokens + 1)) + desc = f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()}) [all tokens]" + + pbar = tqdm(total=len(m_values), desc=desc) + + for num_tokens in m_values: fp8_gemm_nt( (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens] ) pbar.update(1) - num_tokens -= 1 FP8_GEMM_NT_WARMUP_CACHE.add(w.size()) @@ -167,7 +230,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" - block_m = deep_gemm_block_shape()[0] + block_m = get_mk_alignment_for_contiguous_layout()[0] num_experts = w1.size(0) device = w1.device @@ -195,12 +258,16 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor): ) out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) + # Generate M values in block_m increments (already optimized for MoE) + m_values = list(range(block_m, MAX_M + 1, block_m)) + pbar = tqdm( - total=MAX_BLOCKS, - desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})", + total=len(m_values), + desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()}) " + f"[{len(m_values)} values, block_m={block_m}]", ) - num_tokens = MAX_M - while num_tokens > 0: + + for num_tokens in m_values: m_grouped_fp8_gemm_nt_contiguous( (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), @@ -208,7 +275,6 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor): expert_ids[:num_tokens], ) pbar.update(1) - num_tokens = num_tokens - block_m for w, ws in [(w1, w1_scale), (w2, w2_scale)]: if 
w.size() not in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 23227065ee95..79d1927d3210 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -11,6 +11,7 @@ import torch import vllm.envs as envs +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup from vllm.platforms import current_platform @@ -24,12 +25,26 @@ logger = init_logger(__name__) +def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool: + """ + Record known issues with vllm + flashinfer autotune here. Return True if + and only if flashinfer autotune will run through without issues. + """ + return not ( + vllm_config.parallel_config.data_parallel_size > 1 + and ( + envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 + or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + ) + ) + + def kernel_warmup(worker: "Worker"): # Deep GEMM warmup do_deep_gemm_warmup = ( envs.VLLM_USE_DEEP_GEMM and is_deep_gemm_supported() - and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP + and envs.VLLM_DEEP_GEMM_WARMUP != "skip" ) if do_deep_gemm_warmup: model = worker.get_model() @@ -37,7 +52,11 @@ def kernel_warmup(worker: "Worker"): deep_gemm_warmup(model, max_tokens) # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs - if has_flashinfer() and current_platform.has_device_capability(90): + if ( + has_flashinfer() + and current_platform.has_device_capability(90) + and flashinfer_autotune_supported(worker.vllm_config) + ): flashinfer_autotune(worker.model_runner) # FlashInfer attention warmup diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index d81354d9a399..53052ddc6343 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -3,12 +3,12 @@ import base64 from io import BytesIO from pathlib import Path -from typing import Literal, Optional +from typing import Literal import numpy as np import numpy.typing as npt -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule from .base import MediaIO @@ -53,7 +53,7 @@ class AudioResampler: def __init__( self, - target_sr: Optional[float] = None, + target_sr: float | None = None, method: Literal["librosa", "scipy"] = "librosa", ): self.target_sr = target_sr diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 7febc393157f..c1531cbfdc31 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -5,21 +5,21 @@ from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence from multiprocessing.synchronize import Lock as LockType -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast import torch -from typing_extensions import TypeAlias, override +from typing_extensions import override +import vllm.envs as envs from vllm.distributed.device_communicators.shm_object_storage import ( MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer, ) -from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME from vllm.logger import init_logger -from vllm.utils import GiB_bytes, MiB_bytes -from vllm.utils.cache import LRUCache +from vllm.utils.cache import CacheInfo, LRUCache from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves +from vllm.utils.mem_constants import GiB_bytes, MiB_bytes from .inputs import ( 
MultiModalBatchedField, @@ -85,14 +85,14 @@ def __init__( self.prompt_updates = prompt_updates -MultiModalCacheValue = Union[ - MultiModalProcessorCacheItem, - MultiModalProcessorCacheItemMetadata, - MultiModalKwargsItems, - MultiModalKwargsItem, - MultiModalKwargs, - Mapping[str, NestedTensors], -] +MultiModalCacheValue: TypeAlias = ( + MultiModalProcessorCacheItem + | MultiModalProcessorCacheItemMetadata + | MultiModalKwargsItems + | MultiModalKwargsItem + | MultiModalKwargs + | Mapping[str, NestedTensors] +) _V = TypeVar("_V", bound=MultiModalCacheValue) @@ -256,13 +256,13 @@ def clear_cache(self) -> None: raise NotImplementedError -MultiModalProcessorCacheInItem: TypeAlias = Optional[ - tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] -] +MultiModalProcessorCacheInItem: TypeAlias = ( + tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None +) MultiModalProcessorCacheOutItem: TypeAlias = tuple[ - Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"] + MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"] ] @@ -302,6 +302,16 @@ def is_cached(self, mm_hashes: list[str]) -> list[bool]: """ return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + @abstractmethod + def make_stats(self, *, delta: bool = False) -> CacheInfo: + """ + Get (and reset) the multi-modal cache stats. + + Returns: + The current multi-modal caching stats. + """ + raise NotImplementedError + class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): """ @@ -347,6 +357,10 @@ def get_and_update_item( def clear_cache(self) -> None: self._cache.clear() + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._cache.stat(delta=delta) + class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): """ @@ -397,6 +411,10 @@ def get_and_update_item( def clear_cache(self) -> None: self._cache.clear() + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._cache.stat(delta=delta) + class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): """ @@ -418,7 +436,7 @@ def __init__(self, vllm_config: "VllmConfig") -> None: ring_buffer = SingleWriterShmRingBuffer( data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), - name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, create=True, # sender is the writer ) self._shm_cache = SingleWriterShmObjectStorage( @@ -430,6 +448,20 @@ def __init__(self, vllm_config: "VllmConfig") -> None: # cache (prompt_updates, modality) for P0 only self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {} + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + def _stat(self, *, delta: bool = False) -> CacheInfo: + info = CacheInfo(hits=self._hits, total=self._total) + + if delta: + info_delta = info - self._last_info + self._last_info = info + info = info_delta + + return info + @override def is_cached_item(self, mm_hash: str) -> bool: return self._shm_cache.is_cached(mm_hash) @@ -441,12 +473,17 @@ def get_and_update_item( mm_hash: str, ) -> MultiModalProcessorCacheOutItem: if self._shm_cache.is_cached(mm_hash): + self._hits += 1 + self._total += 1 + address, monotonic_id = self._shm_cache.get_cached(mm_hash) prompt_updates, modality = self._p0_cache[mm_hash] return self.address_as_item(address, monotonic_id, modality), prompt_updates assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + self._total += 1 + try: address, monotonic_id = self._shm_cache.put(mm_hash, 
mm_item[0]) # Try to remove dangling items if p0 cache is too large. @@ -469,6 +506,14 @@ def clear_cache(self) -> None: self._shm_cache.clear() self._p0_cache.clear() + self._hits = 0 + self._total = 0 + self._last_info = CacheInfo(hits=0, total=0) + + @override + def make_stats(self, *, delta: bool = False) -> CacheInfo: + return self._stat(delta=delta) + def remove_dangling_items(self) -> None: """Remove items that are no longer in the shared memory cache.""" cached_hashes = self._shm_cache.key_index.keys() @@ -530,7 +575,7 @@ def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool: def processor_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", -) -> Optional[BaseMultiModalProcessorCache]: +) -> BaseMultiModalProcessorCache | None: """Return a `BaseMultiModalProcessorCache`, if enabled.""" model_config = vllm_config.model_config @@ -557,7 +602,7 @@ def processor_only_cache_from_config( class BaseMultiModalReceiverCache( - BaseMultiModalCache[Optional[MultiModalKwargsItem], MultiModalKwargsItem] + BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem] ): """The required interface for caches on P1.""" @@ -595,7 +640,7 @@ def __init__(self, model_config: "ModelConfig") -> None: @override def get_and_update_item( self, - mm_item: Optional[MultiModalKwargsItem], + mm_item: MultiModalKwargsItem | None, mm_hash: str, ) -> MultiModalKwargsItem: if (cached_item := self._cache.get(mm_hash)) is not None: @@ -633,7 +678,7 @@ def __init__( ring_buffer = SingleWriterShmRingBuffer( data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), - name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, create=False, # Server is a reader ) self._shm_cache = SingleWriterShmObjectStorage( @@ -647,7 +692,7 @@ def __init__( @override def get_and_update_item( self, - mm_item: Optional[MultiModalKwargsItem], + mm_item: MultiModalKwargsItem | None, mm_hash: str, ) -> MultiModalKwargsItem: assert mm_item is not None, f"Expected an address item for {mm_hash=}" @@ -666,7 +711,7 @@ def clear_cache(self) -> None: def engine_receiver_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", -) -> Optional[BaseMultiModalReceiverCache]: +) -> BaseMultiModalReceiverCache | None: """ This is used in the engine process. Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and @@ -690,7 +735,7 @@ def worker_receiver_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", shared_worker_lock: LockType, -) -> Optional[BaseMultiModalReceiverCache]: +) -> BaseMultiModalReceiverCache | None: """ This is used in the worker process. Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and diff --git a/vllm/multimodal/evs.py b/vllm/multimodal/evs.py index 36518c6bdb55..4a288d2d238c 100644 --- a/vllm/multimodal/evs.py +++ b/vllm/multimodal/evs.py @@ -9,7 +9,6 @@ # license agreement from NVIDIA CORPORATION is strictly prohibited. 
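Referring back to the hit/total bookkeeping added to `ShmObjectStoreSenderCache` above: a minimal, self-contained sketch of the snapshot-and-subtract pattern behind `make_stats(delta=True)`. The `_Info`/`_Stats` names are stand-ins for illustration, not vLLM's actual `CacheInfo`:

```python
from dataclasses import dataclass


@dataclass
class _Info:  # simplified stand-in for vllm.utils.cache.CacheInfo
    hits: int = 0
    total: int = 0

    def __sub__(self, other: "_Info") -> "_Info":
        return _Info(self.hits - other.hits, self.total - other.total)


class _Stats:
    def __init__(self) -> None:
        self._hits = 0
        self._total = 0
        self._last = _Info()

    def record(self, hit: bool) -> None:
        self._hits += int(hit)
        self._total += 1

    def stat(self, *, delta: bool = False) -> _Info:
        info = _Info(self._hits, self._total)
        if delta:
            # Report only what happened since the previous delta query.
            info, self._last = info - self._last, info
        return info


s = _Stats()
for hit in (True, False, True):
    s.record(hit)
assert s.stat(delta=True) == _Info(hits=2, total=3)
s.record(True)
assert s.stat(delta=True) == _Info(hits=1, total=1)
```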
import typing -from typing import Union import torch @@ -38,7 +37,7 @@ def compute_retained_tokens_count( def compute_retention_mask( video_embeds: torch.Tensor, - video_size_thw: Union[torch.LongTensor, tuple[int, int, int]], + video_size_thw: torch.LongTensor | tuple[int, int, int], spatial_merge_size: int, q: float, ) -> torch.Tensor: diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index 91d86cd9a189..d0dcbb25fcce 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -4,7 +4,6 @@ import pickle import uuid from collections.abc import Iterable -from typing import Union import numpy as np import torch @@ -18,7 +17,7 @@ class MultiModalHasher: @classmethod - def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]: + def serialize_item(cls, obj: object) -> Iterable[bytes | memoryview]: # Simple cases if isinstance(obj, (bytes, memoryview)): return (obj,) @@ -84,7 +83,7 @@ def iter_item_to_bytes( cls, key: str, obj: object, - ) -> Iterable[Union[bytes, memoryview]]: + ) -> Iterable[bytes | memoryview]: # Recursive cases if isinstance(obj, (list, tuple)): for i, elem in enumerate(obj): diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index f50ab1faebba..21e8bef97a78 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -3,7 +3,6 @@ from io import BytesIO from pathlib import Path -from typing import Union import pybase64 import torch @@ -26,7 +25,7 @@ def rescale_image_size( def rgba_to_rgb( image: Image.Image, - background_color: Union[tuple[int, int, int], list[int]] = (255, 255, 255), + background_color: tuple[int, int, int] | list[int] = (255, 255, 255), ) -> Image.Image: """Convert an RGBA image to RGB with filled background color.""" assert image.mode == "RGBA" diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index bec3099a99bc..a05f54191f04 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -7,12 +7,23 @@ from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast, final +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Optional, + TypeAlias, + TypedDict, + Union, + cast, + final, +) import numpy as np -from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated +from typing_extensions import NotRequired, TypeVar, deprecated -from vllm.utils import LazyLoader, full_groupby, is_list_of +from vllm.utils.collection_utils import full_groupby, is_list_of +from vllm.utils.import_utils import LazyLoader from vllm.utils.jsontree import json_map_leaves if TYPE_CHECKING: @@ -85,7 +96,7 @@ these are directly passed to the model without HF processing. """ -ModalityData: TypeAlias = Union[_T, list[Optional[_T]], None] +ModalityData: TypeAlias = _T | list[_T | None] | None """ Either a single data item, or a list of data items. Can only be None if UUID is provided. @@ -117,7 +128,7 @@ class MultiModalDataBuiltins(TypedDict, total=False): [`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins]. """ -MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]] +MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str] """ A dictionary containing user-provided UUIDs for items in each modality. 
If a UUID for an item is not provided, its entry will be `None` and @@ -412,7 +423,7 @@ class MultiModalFlatField(BaseMultiModalField): [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes] """ - slices: Union[Sequence[slice], Sequence[Sequence[slice]]] + slices: Sequence[slice] | Sequence[Sequence[slice]] dim: int = 0 def build_elems( @@ -524,7 +535,7 @@ def batched(modality: str): @staticmethod def flat( modality: str, - slices: Union[Sequence[slice], Sequence[Sequence[slice]]], + slices: Sequence[slice] | Sequence[Sequence[slice]], dim: int = 0, ): """ @@ -729,7 +740,7 @@ def get_data(self) -> dict[str, NestedTensors]: _I = TypeVar( "_I", MultiModalKwargsItem, - Optional[MultiModalKwargsItem], + MultiModalKwargsItem | None, default=MultiModalKwargsItem, ) @@ -818,10 +829,10 @@ def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": ) -MultiModalKwargsOptionalItems: TypeAlias = Union[ - MultiModalKwargsItems[MultiModalKwargsItem], - MultiModalKwargsItems[Optional[MultiModalKwargsItem]], -] +MultiModalKwargsOptionalItems: TypeAlias = ( + MultiModalKwargsItems[MultiModalKwargsItem] + | MultiModalKwargsItems[MultiModalKwargsItem | None] +) class MultiModalKwargs(UserDict[str, NestedTensors]): diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 8fdc5cf721d0..1ae2c7408a66 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -10,16 +10,17 @@ Generic, Literal, NamedTuple, - Optional, + TypeAlias, + TypeGuard, TypeVar, - Union, ) import numpy as np import torch -from typing_extensions import TypeAlias, TypeGuard, assert_never +from typing_extensions import assert_never -from vllm.utils import LazyLoader, is_list_of +from vllm.utils.collection_utils import is_list_of +from vllm.utils.import_utils import LazyLoader from .audio import AudioResampler from .inputs import ( @@ -111,7 +112,7 @@ def get_passthrough_data(self) -> Mapping[str, object]: class EmbeddingItems( - ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor] + ModalityDataItems[torch.Tensor | list[torch.Tensor], torch.Tensor] ): """ Base class for data items that are expressed as a batched embedding tensor, @@ -195,7 +196,7 @@ def get_passthrough_data(self) -> Mapping[str, object]: class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): - def __init__(self, data: Optional[Sequence[HfAudioItem]]) -> None: + def __init__(self, data: Sequence[HfAudioItem] | None) -> None: if data is None: data = [None] super().__init__(data, "audio") @@ -206,7 +207,7 @@ def get_audio_length(self, item_idx: int) -> int: class AudioEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: + def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None: super().__init__(data, "audio") @@ -216,7 +217,7 @@ class ImageSize(NamedTuple): class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): - def __init__(self, data: Optional[Sequence[HfImageItem]]) -> None: + def __init__(self, data: Sequence[HfImageItem] | None) -> None: if data is None: data = [None] super().__init__(data, "image") @@ -234,17 +235,15 @@ def get_image_size(self, item_idx: int) -> ImageSize: class ImageEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: + def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None: super().__init__(data, "image") class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): def __init__( self, - data: 
Optional[Sequence[HfVideoItem]], - metadata: Optional[ - Union[dict[str, Any], list[Optional[dict[str, Any]]]] - ] = None, + data: Sequence[HfVideoItem] | None, + metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None, ) -> None: if data is None: data = [None] @@ -267,7 +266,7 @@ def get_frame_size(self, item_idx: int) -> ImageSize: class VideoEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: + def __init__(self, data: torch.Tensor | list[torch.Tensor]) -> None: super().__init__(data, "video") @@ -306,7 +305,7 @@ def get_all_counts(self) -> Mapping[str, int]: def get_items( self, modality: str, - typ: Union[type[_D], tuple[type[_D], ...]], + typ: type[_D] | tuple[type[_D], ...], ) -> _D: """ Get the data items belonging to a modality, @@ -331,7 +330,7 @@ def get_items( ModalityDataParser: TypeAlias = Callable[ - [ModalityData[Any]], Optional[ModalityDataItems[Any, Any]] + [ModalityData[Any]], ModalityDataItems[Any, Any] | None ] @@ -348,7 +347,7 @@ class MultiModalDataParser: def __init__( self, *, - target_sr: Optional[float] = None, + target_sr: float | None = None, audio_resample_method: Literal["librosa", "scipy"] = "librosa", video_needs_metadata: bool = False, ) -> None: @@ -362,11 +361,11 @@ def __init__( def _is_embeddings( self, data: object - ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]: + ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]: if isinstance(data, torch.Tensor): return data.ndim == 3 if is_list_of(data, torch.Tensor): - return data[0].ndim == 2 + return data[0].ndim == 2 # type: ignore[index] return False @@ -381,7 +380,7 @@ def _is_empty(self, data: object) -> TypeGuard[None]: def _get_audio_with_sr( self, audio: AudioItem, - ) -> tuple[np.ndarray, Optional[float]]: + ) -> tuple[np.ndarray, float | None]: if isinstance(audio, tuple): return audio if isinstance(audio, list): @@ -396,7 +395,7 @@ def _get_audio_with_sr( def _get_video_with_metadata( self, video: VideoItem, - ) -> tuple[np.ndarray, Optional[dict[str, Any]]]: + ) -> tuple[np.ndarray, dict[str, Any] | None]: if isinstance(video, tuple): return video if isinstance(video, list): @@ -411,7 +410,7 @@ def _get_video_with_metadata( def _parse_audio_data( self, data: ModalityData[AudioItem], - ) -> Optional[ModalityDataItems[Any, Any]]: + ) -> ModalityDataItems[Any, Any] | None: if data is None: return AudioProcessorItems(None) @@ -424,6 +423,7 @@ def _parse_audio_data( if self._is_embeddings(data): return AudioEmbeddingItems(data) + data_items: list[AudioItem] if ( is_list_of(data, float) or isinstance(data, (np.ndarray, torch.Tensor)) @@ -434,7 +434,7 @@ def _parse_audio_data( elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] else: - data_items = data + data_items = data # type: ignore[assignment] new_audios = list[np.ndarray]() for data_item in data_items: @@ -451,7 +451,7 @@ def _parse_audio_data( def _parse_image_data( self, data: ModalityData[ImageItem], - ) -> Optional[ModalityDataItems[Any, Any]]: + ) -> ModalityDataItems[Any, Any] | None: if data is None: return ImageProcessorItems(None) @@ -477,7 +477,7 @@ def _parse_image_data( def _parse_video_data( self, data: ModalityData[VideoItem], - ) -> Optional[ModalityDataItems[Any, Any]]: + ) -> ModalityDataItems[Any, Any] | None: if data is None: return VideoProcessorItems(None) @@ -487,6 +487,7 @@ def _parse_video_data( if self._is_embeddings(data): return VideoEmbeddingItems(data) + data_items: list[VideoItem] if ( is_list_of(data, 
PILImage.Image) or isinstance(data, (np.ndarray, torch.Tensor)) @@ -498,10 +499,10 @@ def _parse_video_data( elif isinstance(data, tuple) and len(data) == 2: data_items = [data] else: - data_items = data + data_items = data # type: ignore[assignment] - new_videos = list[tuple[np.ndarray, Optional[dict[str, Any]]]]() - metadata_lst: list[Optional[dict[str, Any]]] = [] + new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]() + metadata_lst: list[dict[str, Any] | None] = [] for data_item in data_items: video, metadata = self._get_video_with_metadata(data_item) if self.video_needs_metadata: diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 5c3739e29d10..55132a6036ef 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -12,9 +12,8 @@ Any, Generic, NamedTuple, - Optional, Protocol, - Union, + TypeAlias, cast, overload, ) @@ -26,7 +25,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens -from vllm.utils import flatten_2d_lists, full_groupby, get_allowed_kwarg_only_overrides +from vllm.utils.collection_utils import flatten_2d_lists, full_groupby +from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.jsontree import JSONTree, json_map_leaves from .hasher import MultiModalHasher @@ -57,12 +57,20 @@ from .cache import BaseMultiModalProcessorCache from .profiling import BaseDummyInputsBuilder +else: + PretrainedConfig = object + BatchFeature = object + ProcessorMixin = object + + ModelConfig = object + + BaseMultiModalProcessorCache = object logger = init_logger(__name__) _S = TypeVar("_S", str, list[int]) -PromptSeq = Union[str, list[int]] +PromptSeq: TypeAlias = str | list[int] """A token sequence (list of token IDs) or text.""" @@ -71,7 +79,7 @@ def _cached_encode( tokenizer: AnyTokenizer, text: str, *, - add_special_tokens: Optional[bool] = None, + add_special_tokens: bool | None = None, ) -> list[int]: return encode_tokens(tokenizer, text, add_special_tokens=add_special_tokens) @@ -81,7 +89,7 @@ def _cached_decode( tokenizer: AnyTokenizer, token_ids: tuple[int, ...], *, - skip_special_tokens: Optional[bool] = None, + skip_special_tokens: bool | None = None, ) -> str: return decode_tokens( tokenizer, list(token_ids), skip_special_tokens=skip_special_tokens @@ -108,7 +116,7 @@ def __call__( tokenizer: AnyTokenizer, prompt: PromptSeq, start_idx: int = 0, - ) -> Optional[int]: ... + ) -> int | None: ... @dataclass @@ -138,7 +146,7 @@ def get_match_index( tokenizer: AnyTokenizer, prompt: PromptSeq, start_idx: int = 0, - ) -> Optional[int]: + ) -> int | None: if start_idx != 0: return None @@ -168,12 +176,12 @@ def end() -> PromptIndex: return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt)) -UpdateTarget = Union[PromptSeq, PromptIndex] +UpdateTarget: TypeAlias = PromptSeq | PromptIndex """ The token sequence or text to update. 
""" -PromptUpdateTarget = Union[Callable[[int], UpdateTarget], UpdateTarget] +PromptUpdateTarget: TypeAlias = Callable[[int], UpdateTarget] | UpdateTarget """ Given the index of the processed item within [`modality`][vllm.multimodal.processing.PromptUpdate.modality], @@ -191,7 +199,7 @@ class PromptUpdateDetails(Generic[_S]): full: _S """The full content.""" - is_embed: Optional[Callable[[AnyTokenizer, PromptSeq], torch.Tensor]] = None + is_embed: Callable[[AnyTokenizer, PromptSeq], torch.Tensor] | None = None """ Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], return a boolean mask of shape `(len(full),)` indicating which positions @@ -236,7 +244,7 @@ def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: return PromptUpdateDetails(full=seq, is_embed=is_embed) -PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] +PromptUpdateInfo: TypeAlias = PromptSeq | PromptUpdateDetails """ The token sequence or text that are part of the update. @@ -245,7 +253,7 @@ def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: specify which part. """ -PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], PromptUpdateInfo] +PromptUpdateContent: TypeAlias = Callable[[int], PromptUpdateInfo] | PromptUpdateInfo """ Given the index of the processed item within [`modality`][vllm.multimodal.processing.PromptUpdate.modality], @@ -324,8 +332,8 @@ class PromptInsertion(PromptUpdate): Example: - For each image, insert a number of ``<image>`` feature placeholders - equal to the feature size of the vision encoder after the ``<s>`` token: + For each image, insert a number of `<image>` feature placeholders + equal to the feature size of the vision encoder after the `<s>` token: ```python PromptInsertion( @@ -345,7 +353,7 @@ class PromptInsertion(PromptUpdate): ) ``` - Insert these tokens after a prefix ``Images:``: + Insert these tokens after a prefix `Images:`: ```python PromptInsertion( @@ -393,8 +401,8 @@ class PromptReplacement(PromptUpdate): Example: - For each image, replace one ``<image>`` input placeholder in the prompt - with a number of ``<image>`` feature placeholders + For each image, replace one `<image>` input placeholder in the prompt + with a number of `<image>` feature placeholders equal to the feature size of the vision encoder: ```python @@ -405,8 +413,8 @@ class PromptReplacement(PromptUpdate): ) ``` - As above, but further pad the feature placeholders with ``<image_bos>`` - and `<image_eos>``, which are not supposed to be passed to the vision + As above, but further pad the feature placeholders with `<image_bos>` + and `<image_eos>`, which are not supposed to be passed to the vision encoder: ```python @@ -472,12 +480,15 @@ class _HasModalityProp(Protocol): def modality(self) -> str: ... -_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp]) +_M = TypeVar("_M", bound=_HasModalityAttr | _HasModalityProp) def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: - """Convenience function to apply [`full_groupby`][vllm.utils.full_groupby] - based on modality.""" + """ + Convenience function to apply + [`full_groupby`][vllm.utils.collection_utils.full_groupby] + based on modality. 
+ """ return full_groupby(values, key=lambda x: x.modality) @@ -554,7 +565,7 @@ def iter_text_matches( def iter_matches( self, - prompt: Union[list[int], str], + prompt: list[int] | str, tokenizer: AnyTokenizer, *, start_idx: int = 0, @@ -642,7 +653,7 @@ class PlaceholderFeaturesInfo: item_idx: int start_idx: int tokens: list[int] - is_embed: Optional[torch.Tensor] + is_embed: torch.Tensor | None @property def length(self) -> int: @@ -668,8 +679,8 @@ def _find_matches( *, prev_end_idx: int = 0, current_result: "MultiModalPromptUpdatesApplyResult", -) -> tuple[Optional[UpdateMode], list[_MatchToApply]]: - mode: Optional[UpdateMode] = None +) -> tuple[UpdateMode | None, list[_MatchToApply]]: + mode: UpdateMode | None = None mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]() for modality, modality_updates in mm_prompt_updates.items(): @@ -723,7 +734,7 @@ def _apply_matches( ) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: prompt_len = len(prompt) - out_seqs = list[Union[str, list[int]]]() + out_seqs = list[str | list[int]]() out_result: MultiModalPromptUpdatesApplyResult = { m: [None] * len(items) for m, items in mm_prompt_updates.items() } @@ -880,8 +891,8 @@ def find_mm_placeholders( _T = TypeVar("_T") -_C = TypeVar("_C", bound="PretrainedConfig", default="PretrainedConfig") -_P = TypeVar("_P", bound="ProcessorMixin", default="ProcessorMixin") +_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) +_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) @dataclass(frozen=True) @@ -891,25 +902,25 @@ class InputProcessingContext: modify the inputs. """ - model_config: "ModelConfig" + model_config: ModelConfig """The configuration of the model.""" tokenizer: AnyTokenizer """The tokenizer used to tokenize the inputs.""" @overload - def get_hf_config(self, /) -> "PretrainedConfig": ... + def get_hf_config(self, /) -> PretrainedConfig: ... @overload def get_hf_config( self, - typ: Union[type[_C], tuple[type[_C], ...]], + typ: type[_C] | tuple[type[_C], ...], /, ) -> _C: ... def get_hf_config( self, - typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None, + typ: type[Any] | tuple[type[Any], ...] | None = None, /, ) -> Any: """ @@ -955,19 +966,19 @@ def get_mm_config(self): return mm_config @overload - def get_hf_processor(self, /, **kwargs: object) -> "ProcessorMixin": ... + def get_hf_processor(self, /, **kwargs: object) -> ProcessorMixin: ... @overload def get_hf_processor( self, - typ: Union[type[_P], tuple[type[_P], ...]], + typ: type[_P] | tuple[type[_P], ...], /, **kwargs: object, ) -> _P: ... def get_hf_processor( self, - typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None, + typ: type[Any] | tuple[type[Any], ...] | None = None, /, **kwargs: object, ) -> Any: @@ -1026,13 +1037,13 @@ def _postprocess_one(x: object): def call_hf_processor( self, - hf_processor: "ProcessorMixin", + hf_processor: ProcessorMixin, data: Mapping[str, object], kwargs: Mapping[str, object] = {}, *, num_tries: int = 1, max_tries: int = 5, - ) -> Union["BatchFeature", JSONTree]: + ) -> BatchFeature | JSONTree: """ Call `hf_processor` on the prompt `data` (text, image, audio...) with configurable options `kwargs`. 
@@ -1113,10 +1124,10 @@ def model_id(self) -> str: def get_tokenizer(self) -> AnyTokenizer: return self.ctx.tokenizer - def get_hf_config(self) -> "PretrainedConfig": + def get_hf_config(self) -> PretrainedConfig: return self.ctx.get_hf_config() - def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin": + def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: """ Subclasses can override this method to handle specific kwargs from model config or user inputs. @@ -1124,7 +1135,7 @@ def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin": return self.ctx.get_hf_processor(**kwargs) @abstractmethod - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits(self) -> Mapping[str, int | None]: """ Return the maximum supported number of items for each modality. @@ -1156,7 +1167,7 @@ def get_mm_max_tokens_per_item( self, seq_len: int, mm_counts: Mapping[str, int], - ) -> Optional[Mapping[str, int]]: + ) -> Mapping[str, int] | None: """ Return the maximum number of tokens per item of for each modality. @@ -1193,7 +1204,7 @@ def get_mm_max_tokens_per_item( [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ -MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]] +MultiModalPromptUpdatesApplyResult = Mapping[str, list[int | None]] """ For an item `MultiModalPromptUpdates[k][i]`, `MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the @@ -1220,7 +1231,7 @@ def __init__( info: _I, dummy_inputs: "BaseDummyInputsBuilder[_I]", *, - cache: Optional["BaseMultiModalProcessorCache"] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> None: super().__init__() @@ -1248,7 +1259,7 @@ def __call__( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids) @@ -1297,6 +1308,16 @@ def _to_mm_items( [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. """ mm_items = self.data_parser.parse_mm_data(mm_data) + + mm_config = self.info.ctx.model_config.get_multimodal_config() + if not mm_config.enable_mm_embeds: + for modality, items in mm_items.items(): + if isinstance(items, (EmbeddingItems, DictEmbeddingItems)): + raise ValueError( + f"You must set `--enable-mm-embeds` to input " + f"`{modality}_embeds`" + ) + for modality, items in mm_items.items(): self.validate_num_items(modality, len(items)) @@ -1305,7 +1326,7 @@ def _to_mm_items( @abstractmethod def _get_mm_fields_config( self, - hf_inputs: "BatchFeature", + hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: """Given the HF-processed data, output the metadata of each field.""" @@ -1411,7 +1432,7 @@ def _call_hf_processor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], - ) -> "BatchFeature": + ) -> BatchFeature: """ Call the HF processor on the prompt text and associated multi-modal data. @@ -1447,7 +1468,7 @@ def _apply_hf_processor_text_mm( mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - ) -> tuple[list[int], "BatchFeature", bool]: + ) -> tuple[list[int], BatchFeature, bool]: """ Apply the HF processor on the prompt text and multi-modal data together. 
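The new `enable_mm_embeds` check added to `_to_mm_items` above rejects pre-computed embedding inputs unless the flag is set. A simplified, self-contained mirror of that guard (the `_EmbeddingLike` class is only a placeholder for `EmbeddingItems`/`DictEmbeddingItems`):

```python
class _EmbeddingLike:  # placeholder for EmbeddingItems / DictEmbeddingItems
    pass


def _check_mm_embeds(mm_items: dict[str, object], *, enable_mm_embeds: bool) -> None:
    for modality, items in mm_items.items():
        if isinstance(items, _EmbeddingLike) and not enable_mm_embeds:
            raise ValueError(
                f"You must set `--enable-mm-embeds` to input `{modality}_embeds`"
            )


try:
    _check_mm_embeds({"image": _EmbeddingLike()}, enable_mm_embeds=False)
except ValueError as exc:
    assert "--enable-mm-embeds" in str(exc)
```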
@@ -1518,7 +1539,7 @@ def _apply_hf_processor_mm_only( mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - ) -> "BatchFeature": + ) -> BatchFeature: """ Apply the HF processor on the multi-modal data only. @@ -1540,13 +1561,13 @@ def _apply_hf_processor_mm_only( def _apply_hf_processor_main( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, enable_hf_prompt_update: bool, - ) -> tuple[list[int], "BatchFeature", bool]: + ) -> tuple[list[int], BatchFeature, bool]: """ Apply the HF processor on the prompt text and multi-modal data. @@ -1585,7 +1606,7 @@ def _hash_mm_items( hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalHashes: """Create MM hashes to be returned. @@ -1647,7 +1668,7 @@ def _hash_mm_items( def _get_cache_missing_items( self, - cache: "BaseMultiModalProcessorCache", + cache: BaseMultiModalProcessorCache, mm_data_items: MultiModalDataItems, mm_hashes: MultiModalHashes, ) -> MultiModalDataItems: @@ -1692,7 +1713,7 @@ def _recompute_cached_prompt_update( def _merge_mm_kwargs( self, - cache: "BaseMultiModalProcessorCache", + cache: BaseMultiModalProcessorCache, mm_hashes: MultiModalHashes, mm_missing_kwargs: MultiModalKwargsItems, mm_missing_prompt_updates: MultiModalPromptUpdates, @@ -1705,7 +1726,7 @@ def _merge_mm_kwargs( mm_missing_next_idx = defaultdict[str, int](lambda: 0) - merged_kwargs = defaultdict[str, list[Optional[MultiModalKwargsItem]]](list) + merged_kwargs = defaultdict[str, list[MultiModalKwargsItem | None]](list) merged_prompt_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]]( list ) @@ -1714,7 +1735,7 @@ def _merge_mm_kwargs( missing_prompt_updates = mm_missing_prompt_updates.get(modality, []) for item_idx, item_hash in enumerate(hashes): - kwargs: Optional[MultiModalKwargsItem] + kwargs: MultiModalKwargsItem | None if not mm_is_cached[modality][item_idx]: missing_next_idx = mm_missing_next_idx[modality] kwargs = missing_kwargs[missing_next_idx] @@ -1743,12 +1764,12 @@ def _merge_mm_kwargs( def _apply_hf_processor( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, @@ -1791,12 +1812,12 @@ def _apply_hf_processor( def _cached_apply_hf_processor( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, @@ -2026,12 +2047,12 @@ def _maybe_apply_prompt_updates( def apply( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, + tokenization_kwargs: Mapping[str, object] | None = None, *, - mm_uuids: 
Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -2090,9 +2111,9 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): @abstractmethod def create_encoder_prompt( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: + ) -> str | list[int]: """ Create input prompt for the encoder. HF processor will be applied on this prompt during profiling and generation. @@ -2105,15 +2126,15 @@ def pad_dummy_encoder_prompt(self) -> bool: def create_decoder_prompt( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, - ) -> Union[str, list[int]]: + ) -> str | list[int]: """Create input prompt for the decoder.""" return prompt def _get_enc_dec_inputs( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, encoder_inputs: MultiModalInputs, ): @@ -2135,12 +2156,12 @@ def _get_enc_dec_inputs( def apply( self, - prompt: Union[str, list[int]], + prompt: str | list[int], mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Optional[Mapping[str, object]] = None, + tokenization_kwargs: Mapping[str, object] | None = None, *, - mm_uuids: Optional[MultiModalUUIDDict] = None, + mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalEncDecInputs: """ Process multi-modal inputs to be used in vLLM. diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 05ba5a2abdd4..f55bad569e16 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast +from typing import Generic, NamedTuple, TypeVar, cast import numpy as np import numpy.typing as npt @@ -41,7 +41,7 @@ class ProcessorInputs: [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][]. 
""" - prompt: Union[str, list[int]] + prompt: str | list[int] mm_data: MultiModalDataDict hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) tokenization_kwargs: Mapping[str, object] = field(default_factory=dict) @@ -87,7 +87,7 @@ def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalDataDict: """ Build the multimodal input which, after processing, results in @@ -107,7 +107,7 @@ def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> ProcessorInputs: """ Build the input which, after processing, results in @@ -136,7 +136,7 @@ def _get_dummy_audios( *, length: int, num_audios: int, - overrides: Optional[AudioDummyOptions] = None, + overrides: AudioDummyOptions | None = None, ) -> list[npt.NDArray]: if num_audios == 0: return [] @@ -158,7 +158,7 @@ def _get_dummy_images( width: int, height: int, num_images: int, - overrides: Optional[ImageDummyOptions] = None, + overrides: ImageDummyOptions | None = None, ) -> list[Image.Image]: if num_images == 0: return [] @@ -191,7 +191,7 @@ def _get_dummy_videos( height: int, num_frames: int, num_videos: int, - overrides: Optional[VideoDummyOptions] = None, + overrides: VideoDummyOptions | None = None, ) -> list[npt.NDArray]: if num_videos == 0: return [] @@ -223,7 +223,7 @@ def _get_dummy_videos( height, ) height = min(height, overrides.height) - video = np.full((num_frames, width, height, 3), 255) + video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) return [video] * num_videos @@ -254,8 +254,8 @@ def get_mm_limits(self) -> Mapping[str, int]: def _get_dummy_mm_inputs( self, seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_counts: Mapping[str, int] | None = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> MultiModalInputs: if mm_counts is None: mm_counts = self.get_mm_limits() @@ -290,8 +290,8 @@ def _get_mm_num_tokens( def get_encoder_dummy_data( self, seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_counts: Mapping[str, int] | None = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> DummyEncoderData: mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) @@ -324,8 +324,8 @@ def get_encoder_dummy_data( def get_decoder_dummy_data( self, seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, - mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + mm_counts: Mapping[str, int] | None = None, + mm_options: Mapping[str, BaseDummyOptions] | None = None, ) -> DummyDecoderData: mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options) @@ -344,7 +344,7 @@ def get_decoder_dummy_data( def _get_mm_max_tokens( self, seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, + mm_counts: Mapping[str, int] | None = None, mm_embeddings_only: bool = True, ) -> Mapping[str, int]: if mm_counts is None: @@ -355,7 +355,11 @@ def _get_mm_max_tokens( mm_counts=mm_counts, ) if max_tokens_per_item is not None: - return max_tokens_per_item + return { + modality: max_tokens + for modality, max_tokens in max_tokens_per_item.items() + 
if mm_counts.get(modality, 0) > 0 + } mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) @@ -363,7 +367,7 @@ def _get_mm_max_tokens( def get_mm_max_contiguous_tokens( self, seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, + mm_counts: Mapping[str, int] | None = None, ): """ Returns the maximum length of the multimodal (image placeholders+text) @@ -375,5 +379,4 @@ def get_mm_max_contiguous_tokens( This is important to take into account when profiling and initializing the encoder cache size. """ - return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index a526eaff715a..8f9276e84640 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,14 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar import torch.nn as nn from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config -from vllm.utils import ClassRegistry +from vllm.utils.collection_utils import ClassRegistry from .cache import BaseMultiModalProcessorCache from .processing import ( @@ -69,7 +69,7 @@ def __call__( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> BaseMultiModalProcessor[_I]: ... @@ -83,7 +83,7 @@ def build_processor( self, ctx: InputProcessingContext, *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ): info = self.info(ctx) dummy_inputs_builder = self.dummy_inputs(info) @@ -101,7 +101,7 @@ def __init__(self) -> None: def _extract_mm_options( self, model_config: "ModelConfig", - ) -> Optional[Mapping[str, BaseDummyOptions]]: + ) -> Mapping[str, BaseDummyOptions] | None: """ Extract multimodal dummy options from model config. @@ -151,7 +151,8 @@ def get_max_tokens_per_item_by_modality( self, model_config: "ModelConfig", *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, + profiler_limits: Mapping[str, int] | None = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -164,45 +165,20 @@ def get_max_tokens_per_item_by_modality( profiler: MultiModalProfiler = MultiModalProfiler(processor) seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + profiler_limits = ( + profiler.get_mm_limits() if profiler_limits is None else profiler_limits + ) return profiler.get_mm_max_contiguous_tokens( seq_len, - {modality: 1 for modality, limit in mm_limits.items() if limit > 0}, + {modality: 1 for modality, limit in profiler_limits.items() if limit > 0}, ) - def get_max_tokens_per_item_by_nonzero_modality( - self, - model_config: "ModelConfig", - *, - cache: Optional[BaseMultiModalProcessorCache] = None, - ) -> Mapping[str, int]: - """ - Get the maximum number of tokens per data item from each modality based - on underlying model configuration, excluding modalities that user - explicitly disabled via `limit_mm_per_prompt`. 
- - Note: - This is currently directly used only in V1 for profiling the memory - usage of a model. - """ - mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) - max_tokens_per_item = self.get_max_tokens_per_item_by_modality( - model_config, - cache=cache, - ) - - return { - key: max_tokens_per_mm_item - for key, max_tokens_per_mm_item in max_tokens_per_item.items() - if mm_limits[key] > 0 - } - def get_mm_limits_per_prompt( self, model_config: "ModelConfig", *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality @@ -259,7 +235,7 @@ def _get_model_cls(self, model_config: "ModelConfig"): def _create_processing_ctx( self, model_config: "ModelConfig", - tokenizer: Optional[AnyTokenizer] = None, + tokenizer: AnyTokenizer | None = None, ) -> InputProcessingContext: if tokenizer is None and not model_config.skip_tokenizer_init: tokenizer = cached_tokenizer_from_config(model_config) @@ -269,7 +245,7 @@ def _create_processing_info( self, model_config: "ModelConfig", *, - tokenizer: Optional[AnyTokenizer] = None, + tokenizer: AnyTokenizer | None = None, ) -> BaseProcessingInfo: model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls] @@ -280,8 +256,8 @@ def create_processor( self, model_config: "ModelConfig", *, - tokenizer: Optional[AnyTokenizer] = None, - cache: Optional[BaseMultiModalProcessorCache] = None, + tokenizer: AnyTokenizer | None = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. @@ -300,14 +276,14 @@ def get_decoder_dummy_data( self, model_config: "ModelConfig", seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, + mm_counts: Mapping[str, int] | None = None, *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> DummyDecoderData: """ Create dummy data for profiling the memory usage of a model. - The model is identified by ``model_config``. + The model is identified by `model_config`. """ processor = self.create_processor(model_config, cache=cache) profiler: MultiModalProfiler = MultiModalProfiler(processor) @@ -333,14 +309,14 @@ def get_encoder_dummy_data( self, model_config: "ModelConfig", seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, + mm_counts: Mapping[str, int] | None = None, *, - cache: Optional[BaseMultiModalProcessorCache] = None, + cache: BaseMultiModalProcessorCache | None = None, ) -> DummyEncoderData: """ Create dummy data for profiling the memory usage of a model. - The model is identified by ``model_config``. + The model is identified by `model_config`. """ processor = self.create_processor(model_config, cache=cache) profiler: MultiModalProfiler = MultiModalProfiler(processor) @@ -369,7 +345,7 @@ def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: """ if not model_config.is_encoder_decoder: return 0 - max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config) + max_tokens = self.get_max_tokens_per_item_by_modality(model_config) if not max_tokens: # TODO - this function assumes encoder-decoder models are # multimodal. 
This will need to change when adding support for more diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index c9dc077d0385..e97bab250ed1 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from itertools import groupby from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, TypeVar from urllib.parse import ParseResult, urlparse from urllib.request import url2pathname @@ -31,13 +31,11 @@ from .inputs import ( BatchedTensorInputs, MultiModalKwargsItem, - MultiModalKwargsItems, MultiModalPlaceholderDict, ) else: BatchedTensorInputs = Any MultiModalKwargsItem = Any - MultiModalKwargsItems = Any MultiModalPlaceholderDict = Any global_thread_pool = ThreadPoolExecutor( @@ -49,11 +47,11 @@ class MediaConnector: def __init__( self, - media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None, + media_io_kwargs: dict[str, dict[str, Any]] | None = None, connection: HTTPConnection = global_http_connection, *, allowed_local_media_path: str = "", - allowed_media_domains: Optional[list[str]] = None, + allowed_media_domains: list[str] | None = None, ) -> None: """ Args: @@ -143,7 +141,7 @@ def load_from_url( url: str, media_io: MediaIO[_M], *, - fetch_timeout: Optional[int] = None, + fetch_timeout: int | None = None, ) -> _M: # type: ignore[type-var] url_spec = urlparse(url) @@ -173,7 +171,7 @@ async def load_from_url_async( url: str, media_io: MediaIO[_M], *, - fetch_timeout: Optional[int] = None, + fetch_timeout: int | None = None, ) -> _M: url_spec = urlparse(url) loop = asyncio.get_running_loop() @@ -207,7 +205,7 @@ async def load_from_url_async( def fetch_audio( self, audio_url: str, - ) -> tuple[np.ndarray, Union[int, float]]: + ) -> tuple[np.ndarray, int | float]: """ Load audio from a URL. """ @@ -222,7 +220,7 @@ def fetch_audio( async def fetch_audio_async( self, audio_url: str, - ) -> tuple[np.ndarray, Union[int, float]]: + ) -> tuple[np.ndarray, int | float]: """ Asynchronously fetch audio from a URL. """ @@ -396,7 +394,7 @@ def group_mm_kwargs_by_modality( *, device: torch.types.Device = None, pin_memory: bool = False, - merge_by_field_config: Optional[bool] = None, + merge_by_field_config: bool | None = None, ) -> Iterable[tuple[str, int, BatchedTensorInputs]]: """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same modality together into the same `MultiModalKwargs` instance. @@ -432,7 +430,7 @@ def group_mm_kwargs_by_modality( if device is not None: mm_kwargs_group = json_map_leaves( - lambda x: x.to(device=device), + lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x, mm_kwargs_group, ) else: @@ -452,8 +450,8 @@ def group_mm_kwargs_by_modality( def fetch_audio( audio_url: str, - audio_io_kwargs: Optional[dict[str, Any]] = None, -) -> tuple[np.ndarray, Union[int, float]]: + audio_io_kwargs: dict[str, Any] | None = None, +) -> tuple[np.ndarray, int | float]: """ Args: audio_url: URL of the audio file to fetch. 
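A hedged usage sketch of the `MediaConnector` surface above; the `media_io_kwargs` values and the URL are placeholders (the accepted options depend on the per-modality `MediaIO` implementations):

from vllm.multimodal.utils import MediaConnector

connector = MediaConnector(
    media_io_kwargs={"video": {"num_frames": 32}},  # per-modality IO options (example values)
    allowed_media_domains=["media.example.com"],  # reject URLs outside this list
)
# Returns (audio samples, sample rate) per the signature above.
audio, sample_rate = connector.fetch_audio("https://media.example.com/clip.wav")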
@@ -466,7 +464,7 @@ def fetch_audio( def fetch_image( image_url: str, - image_io_kwargs: Optional[dict[str, Any]] = None, + image_io_kwargs: dict[str, Any] | None = None, ) -> Image.Image: """ Args: @@ -480,7 +478,7 @@ def fetch_image( def fetch_video( video_url: str, - video_io_kwargs: Optional[dict[str, Any]] = None, + video_io_kwargs: dict[str, Any] | None = None, ) -> tuple[npt.NDArray, dict[str, Any]]: """ Args: diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 400d6a6be9be..666ef275a924 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -6,17 +6,20 @@ from functools import partial from io import BytesIO from pathlib import Path -from typing import Any, Union +from typing import Any import numpy as np import numpy.typing as npt from PIL import Image from vllm import envs +from vllm.logger import init_logger from .base import MediaIO from .image import ImageMediaIO +logger = init_logger(__name__) + def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray: num_frames, _, _, channels = frames.shape @@ -103,6 +106,7 @@ def load_bytes( cls, data: bytes, num_frames: int = -1, + fps: int = -1, **kwargs, ) -> tuple[npt.NDArray, dict[str, Any]]: import cv2 @@ -116,14 +120,20 @@ def load_bytes( original_fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames_num / original_fps if original_fps > 0 else 0 - # resample video to target num_frames - full_read = num_frames == -1 or total_frames_num < num_frames - if full_read: - num_frames = total_frames_num - frame_idx = list(range(0, num_frames)) + # resample video to target num_frames and fps + # - the minimum of the two will be used + num_frames_to_sample = total_frames_num + if num_frames > 0: + num_frames_to_sample = min(num_frames, total_frames_num) + if fps > 0: + num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps)) + num_frames_to_sample = max(1, num_frames_to_sample) # at least one sample + + if num_frames_to_sample == total_frames_num: + frame_idx = list(range(0, num_frames_to_sample)) else: uniform_sampled_frames = np.linspace( - 0, total_frames_num - 1, num_frames, dtype=int + 0, total_frames_num - 1, num_frames_to_sample, dtype=int ) frame_idx = uniform_sampled_frames.tolist() @@ -132,7 +142,7 @@ def load_bytes( frames = np.empty((len(frame_idx), height, width, 3), dtype=np.uint8) i = 0 - for idx in range(total_frames_num): + for idx in range(max(frame_idx) + 1): ok = cap.grab() if not ok: break @@ -142,8 +152,8 @@ def load_bytes( frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) i += 1 - assert i == num_frames, ( - f"Expected reading {num_frames} frames, " + assert i == num_frames_to_sample, ( + f"Expected reading {num_frames_to_sample} frames, " f"but only loaded {i} frames from video." ) @@ -151,14 +161,14 @@ def load_bytes( # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata # can cause incorrect timestamp calculation without num_frames=-1. 
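        # Worked example of the sampling above (illustrative numbers only): a
        # 300-frame clip at original_fps=30 has duration 10 s, so requesting
        # num_frames=64 and fps=2 gives num_frames_to_sample =
        # min(300, 64, floor(10 * 2)) = 20; frame_idx then comes from
        # np.linspace(0, 299, 20, dtype=int), and "do_sample_frames" below is
        # False because 20 != total_frames_num.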
metadata = { - "total_num_frames": num_frames, - "fps": num_frames / duration, + "total_num_frames": total_frames_num, + "fps": original_fps, "duration": duration, "video_backend": "opencv", - "frames_indices": list(range(num_frames)), + "frames_indices": list(frame_idx), # extra field used to control hf processor's video # sampling behavior - "do_sample_frames": num_frames == total_frames_num, + "do_sample_frames": num_frames_to_sample == total_frames_num, } return frames, metadata @@ -192,7 +202,7 @@ def load_bytes( # Refer to: # https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140 - frame_indices: Union[range, list[int]] + frame_indices: range | list[int] if duration <= max_duration: n = int(math.floor(duration * fps)) frame_indices = sorted( diff --git a/vllm/outputs.py b/vllm/outputs.py index dc183bd8dbe9..cdfe06f1c7fa 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -4,7 +4,7 @@ from collections.abc import MutableSequence from collections.abc import Sequence as GenericSequence from dataclasses import dataclass -from typing import Any, Generic, Optional, Union +from typing import Any, Generic import torch from typing_extensions import TypeVar @@ -41,11 +41,11 @@ class CompletionOutput: index: int text: str token_ids: GenericSequence[int] - cumulative_logprob: Optional[float] - logprobs: Optional[SampleLogprobs] - finish_reason: Optional[str] = None - stop_reason: Union[int, str, None] = None - lora_request: Optional[LoRARequest] = None + cumulative_logprob: float | None + logprobs: SampleLogprobs | None + finish_reason: str | None = None + stop_reason: int | str | None = None + lora_request: LoRARequest | None = None def finished(self) -> bool: return self.finish_reason is not None @@ -108,19 +108,19 @@ class RequestOutput: def __init__( self, request_id: str, - prompt: Optional[str], - prompt_token_ids: Optional[list[int]], - prompt_logprobs: Optional[PromptLogprobs], + prompt: str | None, + prompt_token_ids: list[int] | None, + prompt_logprobs: PromptLogprobs | None, outputs: list[CompletionOutput], finished: bool, - metrics: Optional[Union[RequestMetrics, RequestStateStats]] = None, - lora_request: Optional[LoRARequest] = None, - encoder_prompt: Optional[str] = None, - encoder_prompt_token_ids: Optional[list[int]] = None, - num_cached_tokens: Optional[int] = None, + metrics: RequestMetrics | RequestStateStats | None = None, + lora_request: LoRARequest | None = None, + encoder_prompt: str | None = None, + encoder_prompt_token_ids: list[int] | None = None, + num_cached_tokens: int | None = None, *, - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None, - kv_transfer_params: Optional[dict[str, Any]] = None, + multi_modal_placeholders: MultiModalPlaceholderDict | None = None, + kv_transfer_params: dict[str, Any] | None = None, # Forward compatibility, code that uses args added in new release can # still run with older versions of vLLM without breaking. **kwargs: Any, @@ -201,14 +201,21 @@ class PoolingRequestOutput(Generic[_O]): request_id (str): A unique identifier for the pooling request. outputs (PoolingOutput): The pooling results for the given input. prompt_token_ids (list[int]): A list of token IDs used in the prompt. + num_cached_tokens: The number of tokens with prefix cache hit. finished (bool): A flag indicating whether the pooling is completed. 
""" def __init__( - self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool + self, + request_id: str, + outputs: _O, + prompt_token_ids: list[int], + num_cached_tokens: int, + finished: bool, ): self.request_id = request_id self.prompt_token_ids = prompt_token_ids + self.num_cached_tokens = num_cached_tokens self.finished = finished self.outputs = outputs @@ -217,6 +224,7 @@ def __repr__(self): f"{type(self).__name__}(request_id={self.request_id!r}, " f"outputs={self.outputs!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"num_cached_tokens={self.num_cached_tokens}, " f"finished={self.finished})" ) @@ -255,6 +263,7 @@ def from_base(request_output: PoolingRequestOutput): request_id=request_output.request_id, outputs=EmbeddingOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) @@ -294,6 +303,7 @@ def from_base(request_output: PoolingRequestOutput): request_id=request_output.request_id, outputs=ClassificationOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) @@ -330,5 +340,6 @@ def from_base(request_output: PoolingRequestOutput): request_id=request_output.request_id, outputs=ScoringOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 962e1323b721..f64d7a010b5f 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -3,11 +3,12 @@ import logging import traceback from itertools import chain -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from vllm import envs -from vllm.plugins import load_plugins_by_group -from vllm.utils import resolve_obj_by_qualname, supports_xccl +from vllm.plugins import PLATFORM_PLUGINS_GROUP, load_plugins_by_group +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import supports_xccl from .interface import CpuArchEnum, Platform, PlatformEnum @@ -31,7 +32,7 @@ def vllm_version_matches_substr(substr: str) -> bool: return substr in vllm_version -def tpu_platform_plugin() -> Optional[str]: +def tpu_platform_plugin() -> str | None: logger.debug("Checking if TPU platform is available.") # Check for Pathways TPU proxy @@ -55,7 +56,7 @@ def tpu_platform_plugin() -> Optional[str]: return None -def cuda_platform_plugin() -> Optional[str]: +def cuda_platform_plugin() -> str | None: is_cuda = False logger.debug("Checking if CUDA platform is available.") try: @@ -106,7 +107,7 @@ def cuda_is_jetson() -> bool: return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None -def rocm_platform_plugin() -> Optional[str]: +def rocm_platform_plugin() -> str | None: is_rocm = False logger.debug("Checking if ROCm platform is available.") try: @@ -127,7 +128,7 @@ def rocm_platform_plugin() -> Optional[str]: return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None -def xpu_platform_plugin() -> Optional[str]: +def xpu_platform_plugin() -> str | None: is_xpu = False logger.debug("Checking if XPU platform is available.") try: @@ -154,7 +155,7 @@ def xpu_platform_plugin() -> Optional[str]: return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None -def cpu_platform_plugin() -> Optional[str]: +def 
cpu_platform_plugin() -> str | None: is_cpu = False logger.debug("Checking if CPU platform is available.") try: @@ -188,7 +189,7 @@ def cpu_platform_plugin() -> Optional[str]: def resolve_current_platform_cls_qualname() -> str: - platform_plugins = load_plugins_by_group("vllm.platform_plugins") + platform_plugins = load_plugins_by_group(PLATFORM_PLUGINS_GROUP) activated_plugins = [] @@ -221,10 +222,12 @@ def resolve_current_platform_cls_qualname() -> str: ) elif len(activated_builtin_plugins) == 1: platform_cls_qualname = builtin_platform_plugins[activated_builtin_plugins[0]]() - logger.info("Automatically detected platform %s.", activated_builtin_plugins[0]) + logger.debug( + "Automatically detected platform %s.", activated_builtin_plugins[0] + ) else: platform_cls_qualname = "vllm.platforms.interface.UnspecifiedPlatform" - logger.info("No platform detected, vLLM is running on UnspecifiedPlatform") + logger.debug("No platform detected, vLLM is running on UnspecifiedPlatform") return platform_cls_qualname @@ -261,4 +264,14 @@ def __getattr__(name: str): raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") +def __setattr__(name: str, value): + if name == "current_platform": + global _current_platform + _current_platform = value + elif name in globals(): + globals()[name] = value + else: + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") + + __all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"] diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 24e08a8ecbd7..699a56be5cc4 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -8,8 +8,9 @@ import sys from dataclasses import dataclass from importlib.util import find_spec -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING +import regex as re import torch from vllm.logger import init_logger @@ -127,7 +128,7 @@ def get_attn_backend_cls( selected_backend: "_Backend", head_size: int, dtype: torch.dtype, - kv_cache_dtype: Optional[str], + kv_cache_dtype: str | None, block_size: int, use_v1: bool, use_mla: bool, @@ -150,7 +151,7 @@ def get_attn_backend_cls( @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: import vllm.envs as envs - from vllm.utils import GiB_bytes + from vllm.utils.mem_constants import GiB_bytes kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE if kv_cache_space is None: @@ -246,12 +247,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.enable_dbo = False # Note: workaround for v1 gpu_model_runner - from vllm.config import CompilationLevel + from vllm.config import CompilationMode vllm_config.compilation_config.cudagraph_capture_sizes = [] compilation_config = vllm_config.compilation_config - if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: + if vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: # Note: vLLM V1 is using PIECEWISE level compilation, which will # take time to compile kernels just-in-time with the inductor # backend. 
For CPU CI tests, most of them are executed fast and @@ -264,7 +265,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: else: backend = "inductor" - compilation_config.level = CompilationLevel.DYNAMO_ONCE + compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE compilation_config.backend = backend compilation_config.inductor_compile_config.update( { @@ -276,7 +277,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: ) if vllm_config.lora_config is not None: - compilation_config.level = CompilationLevel.NO_COMPILATION + compilation_config.mode = CompilationMode.NONE assert vllm_config.device_config.device_type == "cpu" @@ -296,6 +297,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + # Disable multi-stream for shared experts as no Stream on CPU + os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "0" + # Intel OpenMP setting ld_prealod_str = os.getenv("LD_PRELOAD", "") if "libiomp5.so" in ld_prealod_str: @@ -334,6 +338,7 @@ def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo] lscpu_output = subprocess.check_output( "lscpu -J -e=CPU,CORE,NODE", shell=True, text=True ) + lscpu_output = re.sub(r'"node":\s*-\s*(,|\n)', r'"node": 0\1', lscpu_output) logical_cpu_list: list[LogicalCPUInfo] = json.loads( lscpu_output, object_hook=LogicalCPUInfo.json_decoder )["cpus"] diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8a4565b4d1a0..637f35a4920e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -5,26 +5,25 @@ """ import os -from datetime import timedelta +from collections.abc import Callable from functools import cache, wraps -from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar import torch -from torch.distributed import PrefixStore, ProcessGroup -from torch.distributed.distributed_c10d import is_nccl_available from typing_extensions import ParamSpec # import custom ops, trigger op registration import vllm._C # noqa import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, import_pynvml +from vllm.utils import import_pynvml +from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.config import VllmConfig else: _Backend = None @@ -85,7 +84,7 @@ def set_device(cls, device: torch.device) -> None: _ = torch.zeros(1, device=device) @classmethod - def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: raise NotImplementedError @classmethod @@ -118,7 +117,15 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing - if model_config is not None and model_config.use_mla: + # Note: block_size is initialized in + # HybridAttentionMambaModelConfig.verify_and_update_config + # for models with both attention and mamba, + # and doesn't need to be reinitialized here + if ( + model_config is not None + and model_config.use_mla + and cache_config.block_size is not None + ): use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") # 
If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, @@ -151,18 +158,22 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: if ( use_flashmla and is_flashmla_dense_supported()[0] - and cache_config.block_size != 64 + and cache_config.block_size % 64 != 0 ): cache_config.block_size = 64 logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") - if use_cutlass_mla and cache_config.block_size != 128: + if use_cutlass_mla and cache_config.block_size % 128 != 0: cache_config.block_size = 128 logger.info( "Forcing kv cache block size to 128 for CUTLASS_MLA backend." ) - if use_flashinfer_mla and cache_config.block_size not in [32, 64]: + if ( + use_flashinfer_mla + and cache_config.block_size != 32 + and cache_config.block_size % 64 != 0 + ): cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashInferMLA backend." @@ -179,7 +190,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: compilation_config = vllm_config.compilation_config if ( - envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + parallel_config.all2all_backend == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 and compilation_config.cudagraph_mode != CUDAGraphMode.NONE ): @@ -191,14 +202,14 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "kernels are optimized for prefill and are incompatible with " "CUDA Graphs. " "In order to use CUDA Graphs for decode-optimized workloads, " - "set VLLM_ALL2ALL_BACKEND to another option, such as " + "use --all2all-backend with another option, such as " "deepep_low_latency, pplx, or allgather_reducescatter." ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE @classmethod def get_current_memory_usage( - cls, device: Optional[torch.types.Device] = None + cls, device: torch.types.Device | None = None ) -> float: torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats(device) @@ -269,12 +280,12 @@ def get_attn_backend_cls( use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) - and block_size == 128 + and block_size % 128 == 0 ) use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( selected_backend is None and cls.is_device_capability(100) - and block_size in [32, 64] + and (block_size == 32 or block_size % 64 == 0) ) use_flashmla = selected_backend == _Backend.FLASHMLA or ( selected_backend is None and is_flashmla_dense_supported()[0] @@ -287,7 +298,9 @@ def get_attn_backend_cls( ) if use_cutlassmla: - logger.info_once("Using Cutlass MLA backend on V1 engine.") + logger.info_once( + "Using Cutlass MLA backend on V1 engine.", scope="local" + ) return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" if use_flashinfermla: from vllm.v1.attention.backends.utils import set_kv_cache_layout @@ -298,7 +311,7 @@ def get_attn_backend_cls( "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" ) if use_flashmla: - if block_size != 64: + if block_size % 64 != 0: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", @@ -442,87 +455,13 @@ def opaque_attention_op(cls) -> bool: def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" - @classmethod - def stateless_init_device_torch_dist_pg( - cls, - backend: str, - prefix_store: PrefixStore, - group_rank: int, - group_size: int, - timeout: 
timedelta, - ) -> ProcessGroup: - assert is_nccl_available() - pg: ProcessGroup = ProcessGroup( - prefix_store, - group_rank, - group_size, - ) - from torch.distributed.distributed_c10d import ProcessGroupNCCL - - backend_options = ProcessGroupNCCL.Options() - backend_options._timeout = timeout - - backend_class = ProcessGroupNCCL( - prefix_store, group_rank, group_size, backend_options - ) - backend_type = ProcessGroup.BackendType.NCCL - device = torch.device("cuda") - pg._set_default_backend(backend_type) - backend_class._set_sequence_number_for_group() - - pg._register_backend(device, backend_type, backend_class) - return pg - @classmethod def device_count(cls) -> int: return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - fp8_attention = kv_cache_dtype.startswith("fp8") - attention_backend = envs.VLLM_ATTENTION_BACKEND - - supported = False - if model_config is not None and model_config.use_mla: - # Default to CutlassMLA for blackwell, - # FlashMLA otherwise - if attention_backend is None: - if cls.is_device_capability(100): - attention_backend = "CUTLASS_MLA" - else: - attention_backend = "FLASHMLA" - - # Only FlashMLA and CUTLASS_MLA support fp8 - if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]: - supported = True - else: - supported = not fp8_attention - else: - # Default to FlashAttention - if attention_backend is None: - attention_backend = "FLASH_ATTN" - - # All Blackwell backends support fp8 - if cls.is_device_capability(100): - supported = True - elif attention_backend == "FLASH_ATTN": - if fp8_attention: - from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 - - supported = flash_attn_supports_fp8() - else: - supported = True - elif attention_backend == "FLASHINFER": - supported = True - elif attention_backend == "TRITON_ATTN": - supported = cls.supports_fp8() - return supported - - @classmethod - def check_if_supports_dtype(cls, torch_dtype: torch.dtype): - if torch_dtype == torch.bfloat16: # noqa: SIM102 + def check_if_supports_dtype(cls, dtype: torch.dtype): + if dtype == torch.bfloat16: # noqa: SIM102 if not cls.has_device_capability(80): capability = cls.get_device_capability() gpu_name = cls.get_device_name() @@ -582,7 +521,7 @@ class NvmlCudaPlatform(CudaPlatformBase): @classmethod @cache @with_nvml_context - def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: try: physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) @@ -595,7 +534,7 @@ def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability] @with_nvml_context def has_device_capability( cls, - capability: Union[tuple[int, int], int], + capability: tuple[int, int] | int, device_id: int = 0, ) -> bool: try: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 59bc9173958c..098e9058f529 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,38 +7,31 @@ import random import sys from datetime import timedelta -from platform import uname -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple import numpy as np import torch -from torch.distributed import PrefixStore, ProcessGroup -from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger if 
TYPE_CHECKING: + from torch.distributed import PrefixStore, ProcessGroup + from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig - from vllm.lora.request import LoRARequest + from vllm.config import VllmConfig + from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import FlexibleArgumentParser else: - _Backend = None - ModelConfig = None - VllmConfig = None - LoRARequest = None - PoolingParams = None - SamplingParams = None - FlexibleArgumentParser = None + FlexibleArgumentParser = object logger = init_logger(__name__) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 - return "microsoft" in " ".join(uname()).lower() + return "microsoft" in " ".join(platform.uname()).lower() class PlatformEnum(enum.Enum): @@ -113,7 +106,7 @@ class Platform: additional_env_vars: list[str] = [] - _global_graph_pool: Optional[Any] = None + _global_graph_pool: Any | None = None @property def supported_dtypes(self) -> list[torch.dtype]: @@ -141,6 +134,9 @@ def is_cpu(self) -> bool: def is_out_of_tree(self) -> bool: return self._enum == PlatformEnum.OOT + def is_unspecified(self) -> bool: + return self._enum == PlatformEnum.UNSPECIFIED + def get_max_output_tokens(self, prompt_len: int) -> int: return sys.maxsize @@ -167,24 +163,18 @@ def device_id_to_physical_device_id(cls, device_id: int): return device_id @classmethod - def import_core_kernels(cls) -> None: + def import_kernels(cls) -> None: """Import any platform-specific C kernels.""" try: import vllm._C # noqa: F401 except ImportError as e: logger.warning("Failed to import from vllm._C: %r", e) - - @classmethod - def try_import_moe_kernels(cls) -> bool: - """Import any platform-specific MoE kernels.""" with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 - return True - return False - @classmethod def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + # Import _Backend here to avoid circular import. from vllm.attention.backends.registry import _Backend return _Backend.TORCH_SDPA @@ -195,7 +185,7 @@ def get_attn_backend_cls( selected_backend: "_Backend", head_size: int, dtype: torch.dtype, - kv_cache_dtype: Optional[str], + kv_cache_dtype: str | None, block_size: int, use_v1: bool, use_mla: bool, @@ -209,14 +199,14 @@ def get_attn_backend_cls( def get_device_capability( cls, device_id: int = 0, - ) -> Optional[DeviceCapability]: + ) -> DeviceCapability | None: """Stateless version of [torch.cuda.get_device_capability][].""" return None @classmethod def has_device_capability( cls, - capability: Union[tuple[int, int], int], + capability: tuple[int, int] | int, device_id: int = 0, ) -> bool: """ @@ -240,7 +230,7 @@ def has_device_capability( @classmethod def is_device_capability( cls, - capability: Union[tuple[int, int], int], + capability: tuple[int, int] | int, device_id: int = 0, ) -> bool: """ @@ -287,7 +277,7 @@ def inference_mode(cls): return torch.inference_mode(mode=True) @classmethod - def seed_everything(cls, seed: Optional[int] = None) -> None: + def seed_everything(cls, seed: int | None = None) -> None: """ Set the seed of each random module. `torch.manual_seed` will set seed on all devices. 
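A quick sketch of the `capability: tuple[int, int] | int` parameter accepted above; both calls ask the same question, since the packed-int form is major * 10 + minor (e.g. (8, 0) -> 80):

from vllm.platforms import current_platform

ampere_or_newer = current_platform.has_device_capability((8, 0))  # tuple form
same_check = current_platform.has_device_capability(80)  # packed-int form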
@@ -308,7 +298,7 @@ def set_device(cls, device: torch.device) -> None: @classmethod def pre_register_and_update( - cls, parser: Optional[FlexibleArgumentParser] = None + cls, parser: FlexibleArgumentParser | None = None ) -> None: """ Do some pre-registration or update action for the current platform. @@ -323,7 +313,7 @@ def pre_register_and_update( pass @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ Check and update the configuration for the current platform. @@ -393,7 +383,7 @@ def is_pin_memory_available(cls) -> bool: @classmethod def get_current_memory_usage( - cls, device: Optional[torch.types.Device] = None + cls, device: torch.types.Device | None = None ) -> float: """ Return the memory usage in bytes. @@ -504,9 +494,9 @@ def opaque_attention_op(cls) -> bool: @classmethod def validate_request( cls, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - processed_inputs: ProcessorInputs, + prompt: "PromptType", + params: "SamplingParams | PoolingParams", + processed_inputs: "ProcessorInputs", ) -> None: """Raises if this request is unsupported on this platform""" @@ -549,27 +539,18 @@ def get_static_graph_wrapper_cls(cls) -> str: def stateless_init_device_torch_dist_pg( cls, backend: str, - prefix_store: PrefixStore, + prefix_store: "PrefixStore", group_rank: int, group_size: int, timeout: timedelta, - ) -> ProcessGroup: + ) -> "ProcessGroup": """ Init platform-specific torch distributed process group. """ - raise RuntimeError(f"Unsupported torch distributed backend: {backend}") - - @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - """ - Returns if the kv_cache_dtype is supported by the current platform. - """ - return False + raise NotImplementedError @classmethod - def check_if_supports_dtype(cls, torch_dtype: torch.dtype): + def check_if_supports_dtype(cls, dtype: torch.dtype): """ Check if the dtype is supported by the current platform. """ @@ -621,7 +602,7 @@ def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: return {} @classmethod - def get_nixl_memory_type(cls) -> Optional[str]: + def get_nixl_memory_type(cls) -> str | None: """ Returns the nixl memory type for the current platform. 
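The `PlatformEnum.OOT` path above mirrors the in-tree `*_platform_plugin()` helpers earlier in this diff; a hedged sketch of an out-of-tree entry point registered under the `vllm.platform_plugins` group (all names below are hypothetical):

def my_accel_platform_plugin() -> str | None:
    # Return the qualified name of a Platform subclass when the vendor stack
    # is importable, or None so vLLM falls through to the other plugins.
    try:
        import my_accel_sdk  # hypothetical vendor runtime
    except ImportError:
        return None
    return "my_vllm_my_accel.platform.MyAccelPlatform"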
""" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 25601011491f..b2ec40849446 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -2,23 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from datetime import timedelta from functools import cache, lru_cache, wraps -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch -from torch.distributed import PrefixStore, ProcessGroup -from torch.distributed.distributed_c10d import is_nccl_available import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend - from vllm.config import ModelConfig, VllmConfig + from vllm.config import VllmConfig else: _Backend = None @@ -81,7 +78,7 @@ "0x74bd": "AMD_Instinct_MI300X_HF", } -# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`` +# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES` if "HIP_VISIBLE_DEVICES" in os.environ: val = os.environ["HIP_VISIBLE_DEVICES"] if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None): @@ -140,8 +137,8 @@ def use_rocm_custom_paged_attention( max_seq_len: int, sliding_window: int, kv_cache_dtype: str, - alibi_slopes: Optional[torch.Tensor] = None, - sinks: Optional[torch.Tensor] = None, + alibi_slopes: torch.Tensor | None = None, + sinks: torch.Tensor | None = None, ) -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) @@ -276,6 +273,9 @@ def get_attn_backend_cls( ) if envs.VLLM_USE_V1: + if selected_backend == _Backend.FLEX_ATTENTION: + logger.info("Using FlexAttention backend on V1 engine.") + return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" if ( envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() ) or selected_backend == _Backend.ROCM_AITER_FA: @@ -317,7 +317,7 @@ def set_device(cls, device: torch.device) -> None: @classmethod @lru_cache(maxsize=8) - def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) @@ -417,7 +417,7 @@ def get_punica_wrapper(cls) -> str: @classmethod def get_current_memory_usage( - cls, device: Optional[torch.types.Device] = None + cls, device: torch.types.Device | None = None ) -> float: torch.cuda.reset_peak_memory_stats(device) return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0] @@ -473,50 +473,13 @@ def is_navi(cls) -> bool: def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" - @classmethod - def stateless_init_device_torch_dist_pg( - cls, - backend: str, - prefix_store: PrefixStore, - group_rank: int, - group_size: int, - timeout: timedelta, - ) -> ProcessGroup: - assert is_nccl_available() - pg: ProcessGroup = ProcessGroup( - prefix_store, - group_rank, - group_size, - ) - from torch.distributed.distributed_c10d import ProcessGroupNCCL - - backend_options = ProcessGroupNCCL.Options() - backend_options._timeout = timeout - - backend_class = ProcessGroupNCCL( - prefix_store, group_rank, group_size, backend_options - ) - backend_type = 
ProcessGroup.BackendType.NCCL - device = torch.device("cuda") - pg._set_default_backend(backend_type) - backend_class._set_sequence_number_for_group() - - pg._register_backend(device, backend_type, backend_class) - return pg - @classmethod def device_count(cls) -> int: return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - return True - - @classmethod - def check_if_supports_dtype(cls, torch_dtype: torch.dtype): - if torch_dtype == torch.bfloat16: # noqa: SIM102 + def check_if_supports_dtype(cls, dtype: torch.dtype): + if dtype == torch.bfloat16: # noqa: SIM102 if not cls.has_device_capability(80): capability = cls.get_device_capability() gpu_name = cls.get_device_name() diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index c0888247f593..ab752f438f72 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Optional, Union, cast +import contextlib +from typing import TYPE_CHECKING, cast import torch from tpu_info import device @@ -15,7 +16,8 @@ if TYPE_CHECKING: from vllm.attention.backends.registry import _Backend - from vllm.config import BlockSize, ModelConfig, VllmConfig + from vllm.config import ModelConfig, VllmConfig + from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams else: BlockSize = None @@ -44,8 +46,10 @@ class TpuPlatform(Platform): additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"] @classmethod - def import_core_kernels(cls) -> None: - pass + def import_kernels(cls) -> None: + # Do not import vllm._C + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 @classmethod def get_attn_backend_cls( @@ -53,7 +57,7 @@ def get_attn_backend_cls( selected_backend: "_Backend", head_size: int, dtype: torch.dtype, - kv_cache_dtype: Optional[str], + kv_cache_dtype: str | None, block_size: int, use_v1: bool, use_mla: bool, @@ -110,7 +114,7 @@ def inference_mode(cls): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - from vllm.config import CompilationLevel, CUDAGraphMode + from vllm.config import CompilationMode, CUDAGraphMode cache_config = vllm_config.cache_config # For v0, the default block size is 16. @@ -118,12 +122,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config.block_size = cast(BlockSize, 16) compilation_config = vllm_config.compilation_config - # TPU only supports DYNAMO_ONCE compilation level - if compilation_config.level != CompilationLevel.DYNAMO_ONCE: + # TPU only supports DYNAMO_TRACE_ONCE compilation mode + if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE: logger.info( - "[TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph." + "[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\ + disabling cudagraph." 
) - compilation_config.level = CompilationLevel.DYNAMO_ONCE + compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE if ( compilation_config.cudagraph_mode is None @@ -207,7 +212,7 @@ def use_all_gather(cls) -> bool: def validate_request( cls, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams | PoolingParams, processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" @@ -217,12 +222,6 @@ def validate_request( ): raise ValueError("Torch XLA does not support per-request seed.") - @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - return True - @classmethod @torch.compile(backend="openxla") def insert_blocks_to_device( diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 2f2f3ab8b9d9..cd65cba6b492 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch @@ -35,8 +36,10 @@ class XPUPlatform(Platform): device_control_env_var: str = "ZE_AFFINITY_MASK" @classmethod - def import_core_kernels(cls) -> None: - pass + def import_kernels(cls) -> None: + # Do not import vllm._C + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 @classmethod def get_attn_backend_cls( @@ -44,13 +47,21 @@ def get_attn_backend_cls( selected_backend: "_Backend", head_size: int, dtype: torch.dtype, - kv_cache_dtype: Optional[str], + kv_cache_dtype: str | None, block_size: int, use_v1: bool, use_mla: bool, has_sink: bool, use_sparse, ) -> str: + from vllm.v1.attention.backends.utils import set_kv_cache_layout + + set_kv_cache_layout("NHD") + logger.info( + "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; " + "only NHD layout is supported by XPU attention kernels." + ) + from vllm.attention.backends.registry import _Backend if use_sparse: @@ -75,22 +86,6 @@ def get_attn_backend_cls( logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - @classmethod - def is_kv_cache_dtype_supported( - cls, kv_cache_dtype: str, model_config: "ModelConfig" - ) -> bool: - """ - Check if the kv_cache_dtype is supported. - XPU only support fp8 kv cache with triton backend. 
- """ - if ( - envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN" - ): - return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] - - return False - @classmethod def set_device(cls, device: torch.device) -> None: """ @@ -102,7 +97,7 @@ def set_device(cls, device: torch.device) -> None: def get_device_capability( cls, device_id: int = 0, - ) -> Optional[DeviceCapability]: + ) -> DeviceCapability | None: # capacity format differs from cuda's and will cause unexpected # failure, so use None directly return None @@ -133,7 +128,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config.block_size = 64 # lazy import to avoid circular import - from vllm.config import CompilationLevel, CUDAGraphMode + from vllm.config import CompilationMode, CUDAGraphMode compilation_config = vllm_config.compilation_config if compilation_config.compile_sizes is None: @@ -144,11 +139,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: ) if vllm_config.lora_config is not None: - compilation_config.level = CompilationLevel.NO_COMPILATION + compilation_config.mode = CompilationMode.NONE # check and update parallel config parallel_config = vllm_config.parallel_config parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" + if vllm_config.kv_transfer_config is not None: + vllm_config.kv_transfer_config.enable_permute_local_kv = True if parallel_config.distributed_executor_backend is None: if parallel_config.world_size > 1: @@ -157,7 +154,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.distributed_executor_backend = "uni" elif parallel_config.distributed_executor_backend == "mp": # FIXME(kunshang): - # spawn needs calling `if __name__ == '__main__':`` + # spawn needs calling `if __name__ == '__main__':` # fork is not supported for xpu start new process. if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn": os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -187,13 +184,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: vllm_config.scheduler_config.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS, ) - from vllm.v1.attention.backends.utils import set_kv_cache_layout - - set_kv_cache_layout("NHD") - logger.info( - "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; " - "only NHD layout is supported by XPU attention kernels." 
- ) @classmethod def support_hybrid_kv_cache(cls) -> bool: @@ -209,7 +199,7 @@ def is_pin_memory_available(cls): @classmethod def get_current_memory_usage( - cls, device: Optional[torch.types.Device] = None + cls, device: torch.types.Device | None = None ) -> float: torch.xpu.reset_peak_memory_stats(device) return torch.xpu.max_memory_allocated(device) @@ -232,8 +222,8 @@ def device_count(cls) -> int: return torch.xpu.device_count() @classmethod - def check_if_supports_dtype(cls, torch_dtype: torch.dtype): - if torch_dtype == torch.bfloat16: # noqa: SIM102 + def check_if_supports_dtype(cls, dtype: torch.dtype): + if dtype == torch.bfloat16: # noqa: SIM102 device_name = cls.get_device_name().lower() # client gpu a770 if device_name.count("a770") > 0: @@ -257,6 +247,10 @@ def insert_blocks_to_device( ) -> None: """Copy blocks from src_cache to dst_cache on XPU.""" _src_cache = src_cache[:, src_block_indices] + if _src_cache.shape[2:] != dst_cache.shape[2:]: + # To support TP_ratio, host KV may be initialized with HND layout + # while XPU device KV uses NHD + _src_cache = _src_cache.permute(0, 1, 3, 2, 4) dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device) @classmethod @@ -269,4 +263,8 @@ def swap_out_blocks_to_host( ) -> None: """Copy blocks from XPU to host (CPU).""" _src_cache = src_cache[:, src_block_indices] + if _src_cache.shape[2:] != dst_cache.shape[2:]: + # XPU device KV uses NHD while host KV + # may be initialized with HND layout for TP_ratio support + _src_cache = _src_cache.permute(0, 1, 3, 2, 4) dst_cache[:, dst_block_indices] = _src_cache.cpu() diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 0c83d49c4593..0d8988f27959 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -2,25 +2,28 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import vllm.envs as envs logger = logging.getLogger(__name__) +# Default plugins group will be loaded in all processes (process0, the engine +# core process and worker processes) DEFAULT_PLUGINS_GROUP = "vllm.general_plugins" +# IO processor plugins group will be loaded in process0 only +IO_PROCESSOR_PLUGINS_GROUP = "vllm.io_processor_plugins" +# Platform plugins group will be loaded in all processes when +# `vllm.platforms.current_platform` is called and the value is not yet initialized. +PLATFORM_PLUGINS_GROUP = "vllm.platform_plugins" # make sure one process only loads plugins once plugins_loaded = False def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: - import sys - - if sys.version_info < (3, 10): - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points + from importlib.metadata import entry_points allowed_plugins = envs.VLLM_PLUGINS diff --git a/vllm/plugins/io_processors/__init__.py b/vllm/plugins/io_processors/__init__.py index 7a914442c4ab..b3a3b548781e 100644 --- a/vllm/plugins/io_processors/__init__.py +++ b/vllm/plugins/io_processors/__init__.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import logging from vllm.config import VllmConfig -from vllm.plugins import load_plugins_by_group +from vllm.plugins import IO_PROCESSOR_PLUGINS_GROUP, load_plugins_by_group from vllm.plugins.io_processors.interface import IOProcessor -from vllm.utils import resolve_obj_by_qualname +from
vllm.utils.import_utils import resolve_obj_by_qualname logger = logging.getLogger(__name__) @@ -39,7 +37,7 @@ def get_io_processor( # Load all installed plugin in the group multimodal_data_processor_plugins = load_plugins_by_group( - "vllm.io_processor_plugins" + IO_PROCESSOR_PLUGINS_GROUP ) loadable_plugins = {} diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py index 84af40d01c43..e0488e48614d 100644 --- a/vllm/plugins/io_processors/interface.py +++ b/vllm/plugins/io_processors/interface.py @@ -3,12 +3,14 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Sequence -from typing import Any, Generic, Optional, TypeVar, Union +from typing import Any, Generic, TypeVar from vllm.config import VllmConfig from vllm.entrypoints.openai.protocol import IOProcessorResponse from vllm.inputs.data import PromptType from vllm.outputs import PoolingRequestOutput +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams IOProcessorInput = TypeVar("IOProcessorInput") IOProcessorOutput = TypeVar("IOProcessorOutput") @@ -22,24 +24,24 @@ def __init__(self, vllm_config: VllmConfig): def pre_process( self, prompt: IOProcessorInput, - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, - ) -> Union[PromptType, Sequence[PromptType]]: + ) -> PromptType | Sequence[PromptType]: raise NotImplementedError async def pre_process_async( self, prompt: IOProcessorInput, - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, - ) -> Union[PromptType, Sequence[PromptType]]: + ) -> PromptType | Sequence[PromptType]: return self.pre_process(prompt, request_id, **kwargs) @abstractmethod def post_process( self, model_output: Sequence[PoolingRequestOutput], - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, ) -> IOProcessorOutput: raise NotImplementedError @@ -47,7 +49,7 @@ def post_process( async def post_process_async( self, model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, ) -> IOProcessorOutput: # We cannot guarantee outputs are returned in the same order they were @@ -63,6 +65,11 @@ async def post_process_async( def parse_request(self, request: Any) -> IOProcessorInput: raise NotImplementedError + def validate_or_generate_params( + self, params: SamplingParams | PoolingParams | None = None + ) -> SamplingParams | PoolingParams: + return params or PoolingParams() + @abstractmethod def output_to_response( self, plugin_output: IOProcessorOutput diff --git a/vllm/plugins/lora_resolvers/filesystem_resolver.py b/vllm/plugins/lora_resolvers/filesystem_resolver.py index c3255af45702..8d94a673e862 100644 --- a/vllm/plugins/lora_resolvers/filesystem_resolver.py +++ b/vllm/plugins/lora_resolvers/filesystem_resolver.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json import os -from typing import Optional import vllm.envs as envs from vllm.lora.request import LoRARequest @@ -15,7 +14,7 @@ def __init__(self, lora_cache_dir: str): async def resolve_lora( self, base_model_name: str, lora_name: str - ) -> Optional[LoRARequest]: + ) -> LoRARequest | None: lora_path = os.path.join(self.lora_cache_dir, lora_name) if os.path.exists(lora_path): adapter_config_path = os.path.join( diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index f7a53503e584..090d92414465 100644 --- a/vllm/pooling_params.py +++ 
b/vllm/pooling_params.py @@ -10,7 +10,7 @@ from vllm.tasks import PoolingTask if TYPE_CHECKING: - from vllm.config import ModelConfig + from vllm.config import ModelConfig, PoolerConfig class PoolingParams( @@ -30,50 +30,36 @@ class PoolingParams( if model support matryoshka representation. activation: Whether to apply activation function to the classification outputs. - softmax: Whether to apply softmax to the reward outputs. """ # --8<-- [start:common-pooling-params] - truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None + truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None # --8<-- [end:common-pooling-params] ## for embeddings models # --8<-- [start:embedding-pooling-params] - dimensions: Optional[int] = None - normalize: Optional[bool] = None + dimensions: int | None = None + normalize: bool | None = None # --8<-- [end:embedding-pooling-params] ## for classification, scoring and rerank # --8<-- [start:classification-pooling-params] - activation: Optional[bool] = None + activation: bool | None = None # --8<-- [end:classification-pooling-params] - ## for reward models - softmax: Optional[bool] = None - step_tag_id: Optional[int] = None - returned_token_ids: Optional[list[int]] = None - - task: Optional[PoolingTask] = None - """Internal use only.""" + ## for step pooling models + step_tag_id: int | None = None + returned_token_ids: list[int] | None = None + ## Internal use only + task: PoolingTask | None = None requires_token_ids: bool = False - """Internal use only.""" - - extra_kwargs: Optional[dict[str, Any]] = None - """Internal use only.""" - + extra_kwargs: dict[str, Any] | None = None output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY @property def all_parameters(self) -> list[str]: - return [ - "dimensions", - "normalize", - "activation", - "softmax", - "step_tag_id", - "returned_token_ids", - ] + return ["dimensions", "normalize", "activation"] @property def valid_parameters(self): @@ -81,7 +67,8 @@ def valid_parameters(self): "embed": ["dimensions", "normalize"], "classify": ["activation"], "score": ["activation"], - "encode": ["softmax", "step_tag_id", "returned_token_ids"], + "token_embed": ["dimensions", "normalize"], + "token_classify": ["activation"], } def clone(self) -> "PoolingParams": @@ -97,10 +84,14 @@ def verify( msg = f"You cannot overwrite {self.task=!r} with {task=!r}!" raise ValueError(msg) + # plugin task uses io_processor.parse_request to verify inputs, + # skipping PoolingParams verify + if self.task == "plugin": + return + # NOTE: Task validation needs to done against the model instance, # which is not available in model config. 
So, it's not included # in this method - self._merge_default_parameters(model_config) self._set_default_parameters(model_config) self._verify_valid_parameters() @@ -125,8 +116,34 @@ def _merge_default_parameters( if getattr(self, k, None) is None: setattr(self, k, getattr(pooler_config, k)) + self._verify_step_pooling(pooler_config, valid_parameters) + + def _verify_step_pooling( + self, pooler_config: "PoolerConfig", valid_parameters: list[str] + ): + step_pooling_parameters = ["step_tag_id", "returned_token_ids"] + if pooler_config.pooling_type != "STEP": + invalid_parameters = [] + for k in step_pooling_parameters: + if getattr(self, k, None) is not None: + invalid_parameters.append(k) + + if invalid_parameters: + raise ValueError( + f"Task {self.task} only supports {valid_parameters} " + f"parameters, does not support " + f"{invalid_parameters} parameters" + ) + else: + for k in step_pooling_parameters: + if getattr(pooler_config, k, None) is None: + continue + + if getattr(self, k, None) is None: + setattr(self, k, getattr(pooler_config, k)) + def _set_default_parameters(self, model_config: Optional["ModelConfig"]): - if self.task == "embed": + if self.task in ["embed", "token_embed"]: if self.normalize is None: self.normalize = True @@ -150,13 +167,9 @@ def _set_default_parameters(self, model_config: Optional["ModelConfig"]): elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") - elif self.task in ["classify", "score"]: + elif self.task in ["classify", "score", "token_classify"]: if self.activation is None: self.activation = True - - elif self.task == "encode": - if self.softmax is None: - self.softmax = True else: raise ValueError(f"Unknown pooling task: {self.task}") @@ -185,7 +198,6 @@ def __repr__(self) -> str: f"normalize={self.normalize}, " f"dimensions={self.dimensions}, " f"activation={self.activation}, " - f"softmax={self.softmax}, " f"step_tag_id={self.step_tag_id}, " f"returned_token_ids={self.returned_token_ids}, " f"requires_token_ids={self.requires_token_ids}, " diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py index fea299b287f9..1c0fce702b3f 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -3,8 +3,9 @@ import copy from collections import defaultdict +from collections.abc import Callable from dataclasses import asdict, dataclass, field -from typing import Any, Callable, Optional, TypeAlias, Union +from typing import Any, Optional, TypeAlias import pandas as pd from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult @@ -62,14 +63,14 @@ class ModelStatsEntry: trace: str -StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] +StatsEntry: TypeAlias = ModelStatsEntry | SummaryStatsEntry @dataclass class _StatsTreeNode: entry: StatsEntry children: list[StatsEntry] - parent: Optional[StatsEntry] + parent: StatsEntry | None @dataclass @@ -82,7 +83,7 @@ class LayerwiseProfileResults(profile): _summary_stats_tree: list[_StatsTreeNode] = field(init=False) # profile metadata - num_running_seqs: Optional[int] = None + num_running_seqs: int | None = None def __post_init__(self): self._build_correlation_map() @@ -150,7 +151,7 @@ def convert_stats_to_dict(self) -> dict[str, Any]: @staticmethod def _indent_row_names_based_on_depth( depths_rows: list[tuple[int, StatsEntry]], - indent_style: Union[Callable[[int], str], str] = " ", + indent_style: Callable[[int], str] | str = " ", ): indented_rows = [] for depth, row in depths_rows: @@ -171,7 +172,7 @@ def 
_build_module_tree(self): event_tree = self._kineto_results.experimental_event_tree() def _df_traversal( - event: _ProfilerEvent, curr_node: Optional[_ModuleTreeNode] = None + event: _ProfilerEvent, curr_node: _ModuleTreeNode | None = None ): # For the tensor parallel case for now only look at task 1 if event.start_tid != 1: @@ -242,7 +243,7 @@ def pct_cuda_time(cuda_time_us): def build_summary_stats_tree_df( node: _ModuleTreeNode, - parent: Optional[_StatsTreeNode] = None, + parent: _StatsTreeNode | None = None, summary_trace: tuple[str] = (), ): if event_has_module(node.event): @@ -287,7 +288,7 @@ def build_summary_stats_tree_df( self._summary_stats_tree.append(build_summary_stats_tree_df(root)) def build_model_stats_tree_df( - node: _ModuleTreeNode, parent: Optional[_StatsTreeNode] = None + node: _ModuleTreeNode, parent: _StatsTreeNode | None = None ): if event_has_module( node.event, @@ -357,7 +358,7 @@ def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]): class layerwise_profile(profile): - def __init__(self, num_running_seqs: Optional[int] = None): + def __init__(self, num_running_seqs: int | None = None): """ layerwise profile constructor. diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py index b3607fbecde7..c95f9f4ac977 100644 --- a/vllm/profiler/utils.py +++ b/vllm/profiler/utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from typing import Callable, Union +from collections.abc import Callable from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata @@ -78,7 +78,7 @@ def _print_line(self): def indent_string( - string: str, indent: int, indent_style: Union[Callable[[int], str], str] = " " + string: str, indent: int, indent_style: Callable[[int], str] | str = " " ) -> str: if indent: if isinstance(indent_style, str): diff --git a/vllm/ray/ray_env.py b/vllm/ray/ray_env.py index a89e55bd7e4b..85623cfe5ff5 100644 --- a/vllm/ray/ray_env.py +++ b/vllm/ray/ray_env.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json import os -from typing import Optional import vllm.envs as envs from vllm.logger import init_logger @@ -32,9 +31,9 @@ def get_env_vars_to_copy( - exclude_vars: Optional[set[str]] = None, - additional_vars: Optional[set[str]] = None, - destination: Optional[str] = None, + exclude_vars: set[str] | None = None, + additional_vars: set[str] | None = None, + destination: str | None = None, ) -> set[str]: """ Get the environment variables to copy to downstream Ray actors. 
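# A standalone, illustrative sketch of the per-task parameter check implied by
# the vllm/pooling_params.py hunk above: each pooling task now only accepts the
# parameters listed in `valid_parameters`, and any other parameter that is set
# has to be rejected. `check_pooling_params` is a hypothetical helper that only
# mirrors the intent of PoolingParams._verify_valid_parameters, whose body is
# not shown in this diff.
ALL_PARAMETERS = ["dimensions", "normalize", "activation"]
VALID_PARAMETERS = {
    "embed": ["dimensions", "normalize"],
    "classify": ["activation"],
    "score": ["activation"],
    "token_embed": ["dimensions", "normalize"],
    "token_classify": ["activation"],
}


def check_pooling_params(task: str, **params) -> None:
    # Reject any known parameter that is set (non-None) but not valid for `task`.
    allowed = VALID_PARAMETERS[task]
    invalid = [
        name
        for name, value in params.items()
        if name in ALL_PARAMETERS and value is not None and name not in allowed
    ]
    if invalid:
        raise ValueError(f"Task {task!r} only supports {allowed}, got {invalid}")


check_pooling_params("embed", dimensions=256, normalize=True)  # accepted
# check_pooling_params("classify", normalize=True)  # would raise ValueError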
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 78d3bf35f2a3..ecee1af43902 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -4,10 +4,13 @@ from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from .basic_parsers import BaseThinkingReasoningParser from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from .deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser +from .ernie45_reasoning_parser import Ernie45ReasoningParser from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser from .gptoss_reasoning_parser import GptOssReasoningParser from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser +from .identity_reasoning_parser import IdentityReasoningParser from .mistral_reasoning_parser import MistralReasoningParser from .olmo3_reasoning_parser import Olmo3ReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser @@ -19,6 +22,9 @@ "BaseThinkingReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser", + "IdentityReasoningParser", + "DeepSeekV3ReasoningParser", + "Ernie45ReasoningParser", "GraniteReasoningParser", "HunyuanA13BReasoningParser", "Qwen3ReasoningParser", diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 2d93f0702f72..ebd660ca5a84 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -1,16 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import os from abc import abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import TYPE_CHECKING, Any +from vllm.entrypoints.tool_server import ToolServer from vllm.logger import init_logger -from vllm.utils import import_from_path, is_list_of +from vllm.utils.collection_utils import is_list_of +from vllm.utils.import_utils import import_from_path if TYPE_CHECKING: from vllm.entrypoints.openai.protocol import ( @@ -78,7 +78,7 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_reasoning_content( self, model_output: str, - request: Union[ChatCompletionRequest, ResponsesRequest], + request: ChatCompletionRequest | ResponsesRequest, ) -> tuple[str | None, str | None]: """ Extract reasoning content from a complete model-generated string. 
@@ -107,7 +107,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Instance method that should be implemented for extracting reasoning from an incomplete response; for use when handling reasoning calls and @@ -116,6 +116,17 @@ def extract_reasoning_content_streaming( previously been parsed and extracted (see constructor) """ + def prepare_structured_tag( + self, + original_tag: str | None, + tool_server: ToolServer | None, + ) -> str | None: + """ + Instance method that can be overridden to prepare the structured tag. + Returns None when no structured tag should be applied. + """ + return None + class ReasoningParserManager: reasoning_parsers: dict[str, type] = {} @@ -136,7 +147,7 @@ def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]: def _register_module( cls, module: type, - module_name: Union[str, list[str]] | None = None, + module_name: str | list[str] | None = None, force: bool = True, ) -> None: if not issubclass(module, ReasoningParser): @@ -158,10 +169,10 @@ def _register_module( @classmethod def register_module( cls, - name: Union[str, list[str]] | None = None, + name: str | list[str] | None = None, force: bool = True, - module: Union[type, None] = None, - ) -> Union[type, Callable]: + module: type | None = None, + ) -> type | Callable: """ Register module with the given name or name list. it can be used as a decoder(with module as None) or normal function(with module as not diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py index b4106a4f5794..621a73b2a59f 100644 --- a/vllm/reasoning/basic_parsers.py +++ b/vllm/reasoning/basic_parsers.py @@ -3,7 +3,6 @@ from abc import abstractmethod from collections.abc import Sequence -from typing import Optional, Union from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, @@ -59,7 +58,8 @@ def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): ) def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.end_token_id in input_ids + end_token_id = self.end_token_id + return any(input_id == end_token_id for input_id in reversed(input_ids)) def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ @@ -78,7 +78,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Extract reasoning content from a delta message. Handles streaming output where previous + delta = current. @@ -134,8 +134,8 @@ def extract_reasoning_content_streaming( return DeltaMessage(content=delta_text) def extract_reasoning_content( - self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest] - ) -> tuple[Optional[str], Optional[str]]: + self, model_output: str, request: ChatCompletionRequest | ResponsesRequest + ) -> tuple[str | None, str | None]: """ Extract reasoning content from the model output.
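# A small, self-contained sketch of why the is_reasoning_end change in
# vllm/reasoning/basic_parsers.py above is written with reversed(): when the
# </think> token is present it usually sits near the tail of input_ids, so a
# scan from the back can exit much earlier than a front-to-back `in` check.
# Both variants below are behaviorally equivalent; token id 42 is just a
# stand-in for the real </think> id.
def is_reasoning_end_forward(input_ids: list[int], end_token_id: int) -> bool:
    return end_token_id in input_ids  # scans from the front


def is_reasoning_end_reversed(input_ids: list[int], end_token_id: int) -> bool:
    return any(tok == end_token_id for tok in reversed(input_ids))  # scans from the back


ids = [1] * 10_000 + [42, 7, 7]  # long reasoning prefix, </think> near the end
assert is_reasoning_end_forward(ids, 42) and is_reasoning_end_reversed(ids, 42)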
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py index 264da54b4879..d5200145ea03 100644 --- a/vllm/reasoning/deepseek_r1_reasoning_parser.py +++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Union from vllm.entrypoints.openai.protocol import DeltaMessage from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager @@ -36,7 +35,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: ret = super().extract_reasoning_content_streaming( previous_text, current_text, diff --git a/vllm/reasoning/deepseek_v3_reasoning_parser.py b/vllm/reasoning/deepseek_v3_reasoning_parser.py new file mode 100644 index 000000000000..7116f90a1ac0 --- /dev/null +++ b/vllm/reasoning/deepseek_v3_reasoning_parser.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from vllm.logger import init_logger +from vllm.reasoning import ( + DeepSeekR1ReasoningParser, + ReasoningParser, + ReasoningParserManager, +) + +from .identity_reasoning_parser import IdentityReasoningParser + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("deepseek_v3") +class DeepSeekV3ReasoningParser(ReasoningParser): + """ + V3 parser that delegates to either DeepSeekR1ReasoningParser or + IdentityReasoningParser based on `thinking` and `separate_reasoning`. 
+ """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.pop("thinking", False)) + + if thinking: + self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning_content(model_output, request) + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + return self._parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/vllm/reasoning/ernie45_reasoning_parser.py b/vllm/reasoning/ernie45_reasoning_parser.py new file mode 100644 index 000000000000..f9d4a30398cf --- /dev/null +++ b/vllm/reasoning/ernie45_reasoning_parser.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager +from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("ernie45") +class Ernie45ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for the Ernie45 thinking model. + The Ernie45 thinking model output format is + abc\n</think>\n\n<response>\ndef\n</response>\n + or abc\n</think>\ndef + """ + + response_start_token: str = "<response>" + response_end_token: str = "</response>" + newline_token: str = "<0x0A>" + + @property + def start_token(self) -> str: + """The token that starts reasoning content.""" + return "<think>" + + @property + def end_token(self) -> str: + """The token that ends reasoning content.""" + return "</think>" + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction." + ) + + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + self.response_start_token_id = self.vocab.get(self.response_start_token) + self.response_end_token_id = self.vocab.get(self.response_end_token) + self.newline_token_id = self.vocab.get(self.newline_token) + + self.parser_token_ids = [self.end_token_id, self.response_end_token_id] + + if self.start_token_id is None or self.end_token_id is None: + raise RuntimeError( + "Ernie45 reasoning parser could not locate think start/end " + "tokens in the tokenizer!"
+ ) + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. + The Ernie45 thinking model output format is + abc\n</think>\n\n<response>\ndef\n</response>\n + or abc\n</think>\ndef + - 'abc' goes to reasoning_content + - 'def' goes to content + """ + # Skip single special tokens + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] + in [ + self.start_token_id, + self.end_token_id, + self.response_start_token_id, + self.response_end_token_id, + ] + ): + return None + + # No <think> in previous or delta; we still need to check for </think>, + # because the model may have generated </think> without <think> + if self.end_token_id in delta_token_ids: + # </think> in delta with more tokens, + # extract reasoning content and content + think_end_index = delta_text.find(self.end_token) + reasoning_content = delta_text[:think_end_index] + content = delta_text[think_end_index + len(self.end_token) :] + content = content.lstrip("\n") + response_start_idx = content.find(self.response_start_token) + response_end_idx = content.rfind(self.response_end_token) + if response_start_idx != -1: + content = content[response_start_idx + len(self.response_start_token) :] + if response_end_idx != -1: + content = content[:response_end_idx] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) + elif self.end_token_id in previous_token_ids: + # </think> in previous, thinking content ends + content = delta_text + if self.response_start_token_id in delta_token_ids: + content = content.lstrip("\n") + response_start_idx = content.find(self.response_start_token) + content = content[response_start_idx + len(self.response_start_token) :] + # if there is a </response>, remove it + response_end_idx = content.rfind(self.response_end_token) + if response_end_idx != -1: + content = content[:response_end_idx] + elif self.response_end_token_id in delta_token_ids: + response_end_idx = content.rfind(self.response_end_token) + content = content[:response_end_idx] + # remove \n after </think> or </response> + if previous_token_ids[-1] in self.parser_token_ids and ( + len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id + ): + content = content.lstrip("\n") + # remove \n after </think>\n + if ( + len(previous_token_ids) > 1 + and previous_token_ids[-2] == self.end_token_id + ) and ( + len(delta_token_ids) > 0 and delta_token_ids[0] == self.newline_token_id + ): + content = content.lstrip("\n") + + return DeltaMessage(content=content if content else None) + else: + # no </think> in previous or delta, reasoning content continues + return DeltaMessage(reasoning_content=delta_text) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[str | None, str | None]: + """ + Extract reasoning content from the model output.
+ The Ernie45 thinking model output format is + abc\n</think>\n\n\n<response>\ndef\n</response>\n + or abc\n</think>\ndef + - 'abc' goes to reasoning_content + - 'def' goes to content + Returns: + tuple[str | None, str | None]: reasoning content and content + """ + reasoning_content, content = super().extract_reasoning_content( + model_output, request + ) + if content: + start_idx = content.find(self.response_start_token) + end_idx = content.rfind(self.response_end_token) + # Both tags are present and in the correct order + if start_idx != -1 and end_idx != -1 and start_idx < end_idx: + content = content[start_idx + len(self.response_start_token) : end_idx] + final_content = content or None + + return reasoning_content, final_content diff --git a/vllm/reasoning/glm4_moe_reasoning_parser.py b/vllm/reasoning/glm4_moe_reasoning_parser.py index da98515c7e62..09cd43c1d555 100644 --- a/vllm/reasoning/glm4_moe_reasoning_parser.py +++ b/vllm/reasoning/glm4_moe_reasoning_parser.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional, Union from transformers import PreTrainedTokenizerBase @@ -80,7 +79,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Extract reasoning content from a delta message. Handles streaming output where previous + delta = current. @@ -137,7 +136,7 @@ def extract_reasoning_content_streaming( def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """ Extract reasoning content from the model output.
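# A minimal sketch of the non-streaming post-processing done by
# Ernie45ReasoningParser.extract_reasoning_content above: after the base parser
# splits reasoning from content at </think>, the <response>...</response>
# wrapper is stripped from the content part, but only when both tags are
# present and in the correct order. `strip_response_wrapper` is an assumed
# helper name used purely for illustration.
RESPONSE_START = "<response>"
RESPONSE_END = "</response>"


def strip_response_wrapper(content: str) -> str | None:
    start_idx = content.find(RESPONSE_START)
    end_idx = content.rfind(RESPONSE_END)
    # Both tags are present and in the correct order
    if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
        content = content[start_idx + len(RESPONSE_START) : end_idx]
    return content or None


print(repr(strip_response_wrapper("\n\n<response>\ndef\n</response>\n")))  # '\ndef\n'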
diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py index 738c7b51694a..e6766ddcbc68 100644 --- a/vllm/reasoning/gptoss_reasoning_parser.py +++ b/vllm/reasoning/gptoss_reasoning_parser.py @@ -1,18 +1,61 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import json from collections.abc import Sequence -from typing import Optional, Union from transformers import PreTrainedTokenizerBase from vllm.entrypoints.harmony_utils import parse_chat_output from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from vllm.entrypoints.tool_server import ToolServer from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager logger = init_logger(__name__) +no_func_reaonsing_tag = { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "tags": [ + { + "begin": "<|channel|>analysis<|message|>", + "content": {"type": "any_text"}, + "end": "<|end|>", + } + ], + "triggers": ["<|channel|>analysis"], + "stop_after_first": False, + }, +} + + +def from_builtin_tool_to_tag(tool: str) -> list[dict]: + tag = [ + { + "begin": f"<|channel|>commentary to={tool}", + "content": {"type": "any_text"}, + "end": "<|end|>", + }, + { + "begin": f"<|channel|>analysis to={tool}", + "content": {"type": "any_text"}, + "end": "<|end|>", + }, + ] + return tag + + +def tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list: list[str]) -> dict: + import copy + + new_tag = copy.deepcopy(no_func_reaonsing_tag) + new_tag["format"]["triggers"].append("<|channel|>commentary to=") + + for tool in builtin_tool_list: + new_tag["format"]["tags"].extend(from_builtin_tool_to_tag(tool)) + return new_tag + @ReasoningParserManager.register_module("openai_gptoss") class GptOssReasoningParser(ReasoningParser): @@ -53,7 +96,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: prev_reasoning, prev_content, _ = parse_chat_output(list(previous_token_ids)) cur_reasoning, cur_content, _ = parse_chat_output(list(current_token_ids)) reasoning_delta = None @@ -78,7 +121,37 @@ def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest, - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: raise NotImplementedError( "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." 
# noqa: E501 ) + + # This function prepares the structural tag to format reasoning output + def prepare_structured_tag( + self, original_tag: str | None, tool_server: ToolServer | None + ) -> str: + if original_tag is None: + if tool_server is None: + return json.dumps(no_func_reaonsing_tag) + else: + builtin_tool_list: list[str] = [] + if tool_server.has_tool("browser"): + builtin_tool_list.append("browser") + if tool_server.has_tool("python"): + builtin_tool_list.append("python") + if tool_server.has_tool("container"): + builtin_tool_list.append("container") + + if len(builtin_tool_list) > 0: + logger.info("Builtin_tool_list: %s", builtin_tool_list) + func_tag = json.dumps( + tag_with_builtin_funcs(no_func_reaonsing_tag, builtin_tool_list) + ) + else: + logger.info("Builtin_tool_list is empty") + func_tag = json.dumps(no_func_reaonsing_tag) + + return func_tag + else: + # There is potential risk for appending the tag to the original tag + return original_tag diff --git a/vllm/reasoning/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py index 543b202989ee..44391f8ad635 100644 --- a/vllm/reasoning/granite_reasoning_parser.py +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional, Union import regex as re from transformers import PreTrainedTokenizerBase @@ -53,7 +52,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """Extract the reasoning content & content sections, respectively. If the sequence doesn't match what we expect, i.e., the model generates something else, all content is considered non-reasoning content. @@ -82,7 +81,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """Extract the reasoning content / content emitted by granite models; If the sequence doesn't match what we expect, i.e., the model generates something else, all content is considered non-reasoning content. @@ -322,7 +321,7 @@ def _get_delta_message_with_both_bounds( def _get_content_sections( self, current_text: str - ) -> tuple[Optional[str], Optional[int], Optional[str]]: + ) -> tuple[str | None, int | None, str | None]: """Parse the text to extract the reasoning content / content if we have them. diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py index 381f1b5f3466..e5cf6f399740 100644 --- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py +++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional, Union import regex as re from transformers import PreTrainedTokenizerBase @@ -90,7 +89,7 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """Extract the reasoning content & content sections, respectively. 
If the sequence doesn't match what we expect, i.e., the model generates something else, all content is considered non-reasoning content. @@ -150,7 +149,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """Extract content using token ID sequence state machine""" # Define sequences think_start_sequence = self.think_start_ids diff --git a/vllm/reasoning/identity_reasoning_parser.py b/vllm/reasoning/identity_reasoning_parser.py new file mode 100644 index 000000000000..f1d17a71be33 --- /dev/null +++ b/vllm/reasoning/identity_reasoning_parser.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser + +logger = init_logger(__name__) + + +class IdentityReasoningParser(ReasoningParser): + """ + Identity reasoning parser. + + This parser does not attempt to parse or strip out reasoning tokens. + It treats the entire model output as content and ignores reasoning. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction." + ) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + # Always return True, since we never treat reasoning specially + return True + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + # Identity: return all tokens as content + return input_ids + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + # Just wrap delta_text as content, ignore reasoning + if delta_text: + return DeltaMessage(content=delta_text) + return None + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[str | None, str | None]: + # No reasoning separation: return None for reasoning_content, + # and full model_output as content + return None, model_output diff --git a/vllm/reasoning/olmo3_reasoning_parser.py b/vllm/reasoning/olmo3_reasoning_parser.py index b330e8b1fdd5..b6c26899a114 100644 --- a/vllm/reasoning/olmo3_reasoning_parser.py +++ b/vllm/reasoning/olmo3_reasoning_parser.py @@ -4,7 +4,7 @@ import dataclasses as dt import enum from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import regex as re @@ -36,7 +36,7 @@ def __len__(self): return self.end - self.start -def string_overlap(a: str, b: str) -> tuple[Optional[Indices], Optional[Indices]]: +def string_overlap(a: str, b: str) -> tuple[Indices | None, Indices | None]: """ Find the longest overlap where the end of string a matches the start of string b. @@ -90,7 +90,7 @@ class Olmo3ReasoningBuffer: # is when we switch to content state. 
state: Olmo3ReasoningState = Olmo3ReasoningState.REASONING - def process_buffer(self) -> Optional[DeltaMessage]: + def process_buffer(self) -> DeltaMessage | None: start_think_idx = self.buffer.find(self.think_start) if start_think_idx >= 0: @@ -142,12 +142,12 @@ def __len__(self): # is the length of the text buffer return len(self.buffer) - def add_text(self, delta_text: str) -> Optional[DeltaMessage]: + def add_text(self, delta_text: str) -> DeltaMessage | None: # we start by adding the delta text to the buffer self.buffer += delta_text # setting this to empty before starting - delta_message: Optional[DeltaMessage] = None + delta_message: DeltaMessage | None = None # we start by computing the overlap between the delta_text # and start/end of think tokens. @@ -254,8 +254,8 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_reasoning_content( self, model_output: str, - request: Union[ChatCompletionRequest, ResponsesRequest], - ) -> tuple[Optional[str], Optional[str]]: + request: ChatCompletionRequest | ResponsesRequest, + ) -> tuple[str | None, str | None]: """Extract the reasoning content & content sections, respectively. If the sequence doesn't match what we expect, i.e., the model generates something else, all content is considered non-reasoning content. @@ -287,7 +287,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """Extract content using token ID sequence state machine""" delta_message = self.buffer.add_text(delta_text) diff --git a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py index 160e8633a43f..2ec06720719d 100644 --- a/vllm/reasoning/qwen3_reasoning_parser.py +++ b/vllm/reasoning/qwen3_reasoning_parser.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ResponsesRequest from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager @@ -31,8 +30,8 @@ def end_token(self) -> str: return "</think>" def extract_reasoning_content( - self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest] - ) -> tuple[Optional[str], Optional[str]]: + self, model_output: str, request: ChatCompletionRequest | ResponsesRequest + ) -> tuple[str | None, str | None]: """ Extract reasoning content from the model output. diff --git a/vllm/reasoning/step3_reasoning_parser.py b/vllm/reasoning/step3_reasoning_parser.py index c9f580077b33..ae066d96f250 100644 --- a/vllm/reasoning/step3_reasoning_parser.py +++ b/vllm/reasoning/step3_reasoning_parser.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import Optional, Union import regex as re from transformers import PreTrainedTokenizerBase @@ -50,7 +49,7 @@ def extract_reasoning_content_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: """ Extract reasoning content from a delta message. Handles streaming output where previous + delta = current. 
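# A simplified sketch of the idea behind string_overlap in the Olmo3 parser
# above: a streaming parser has to hold back the tail of its buffer whenever
# that tail could be the start of a think tag that has not fully arrived yet.
# `longest_suffix_prefix_overlap` is a deliberately simplified stand-in, not
# the Indices-returning helper from the diff.
def longest_suffix_prefix_overlap(buffer: str, tag: str) -> int:
    # Length of the longest suffix of `buffer` that is a prefix of `tag`.
    for k in range(min(len(buffer), len(tag)), 0, -1):
        if buffer.endswith(tag[:k]):
            return k
    return 0


# The trailing "</th" might be the start of "</think>", so a streaming parser
# would buffer those 4 characters instead of emitting them as content.
assert longest_suffix_prefix_overlap("some reasoning </th", "</think>") == 4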
@@ -81,7 +80,7 @@ def extract_reasoning_content_streaming( def extract_reasoning_content( self, model_output: str, request: ChatCompletionRequest - ) -> tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: # Check if the model output contains the </think> token if self.think_end_token not in model_output: # If no </think> token, everything is reasoning content diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index a1ff4e5ff63b..4b2a3bc4dbaa 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -7,7 +7,7 @@ from dataclasses import field from enum import Enum, IntEnum from functools import cached_property -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any import msgspec from pydantic.dataclasses import dataclass @@ -32,19 +32,19 @@ class SamplingType(IntEnum): @dataclass class StructuredOutputsParams: # One of these fields will be used to build a logit processor. - json: Optional[Union[str, dict]] = None - regex: Optional[str] = None - choice: Optional[list[str]] = None - grammar: Optional[str] = None - json_object: Optional[bool] = None + json: str | dict | None = None + regex: str | None = None + choice: list[str] | None = None + grammar: str | None = None + json_object: bool | None = None # These are other options that can be set. disable_fallback: bool = False disable_any_whitespace: bool = False disable_additional_properties: bool = False - whitespace_pattern: Optional[str] = None - structural_tag: Optional[str] = None + whitespace_pattern: str | None = None + structural_tag: str | None = None - _backend: Optional[str] = field(default=None, init=False) + _backend: str | None = field(default=None, init=False) """CAUTION: Should only be set by Processor._validate_structured_output""" _backend_was_auto: bool = field(default=False, init=False) """CAUTION: Should only be set by Processor._validate_structured_output""" @@ -58,6 +58,7 @@ def __post_init__(self): self.choice is not None, self.grammar is not None, self.json_object is not None, + self.structural_tag is not None, ] ) if count > 1: @@ -66,6 +67,37 @@ def __post_init__(self): f"but multiple are specified: {self.__dict__}" ) + def all_constraints_none(self) -> bool: + """ + Returns True if all structured-output constraint fields are None. + """ + return all( + getattr(self, field) is None + for field in ( + "json", + "regex", + "choice", + "grammar", + "json_object", + "structural_tag", + ) + ) + + def all_non_structural_tag_constraints_none(self) -> bool: + """ + Returns True if all constraint fields except structural_tag are None. + """ + return all( + getattr(self, field) is None + for field in ( + "json", + "regex", + "choice", + "grammar", + "json_object", + ) + ) + + @dataclass class GuidedDecodingParams(StructuredOutputsParams): @@ -110,12 +142,12 @@ class SamplingParams( are generated and streamed cumulatively per request. To see all `n` outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY` in `SamplingParams`.""" - best_of: Optional[int] = None + best_of: int | None = None """Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. `best_of` must be greater than or equal to `n`. By default, `best_of` is set to `n`. Warning, this is only supported in V0.""" - _real_n: Optional[int] = None + _real_n: int | None = None presence_penalty: float = 0.0 """Penalizes new tokens based on whether they appear in the generated text so far.
Values > 0 encourage the model to use new tokens, while values < 0 @@ -142,24 +174,24 @@ class SamplingParams( """Represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this.""" - seed: Optional[int] = None + seed: int | None = None """Random seed to use for the generation.""" - stop: Optional[Union[str, list[str]]] = None + stop: str | list[str] | None = None """String(s) that stop the generation when they are generated. The returned output will not contain the stop strings.""" - stop_token_ids: Optional[list[int]] = None + stop_token_ids: list[int] | None = None """Token IDs that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are special tokens.""" ignore_eos: bool = False """Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.""" - max_tokens: Optional[int] = 16 + max_tokens: int | None = 16 """Maximum number of tokens to generate per output sequence.""" min_tokens: int = 0 """Minimum number of tokens to generate per output sequence before EOS or `stop_token_ids` can be generated""" - logprobs: Optional[int] = None + logprobs: int | None = None """Number of log probabilities to return per output token. When set to `None`, no probability is returned. If set to a non-`None` value, the result includes the log probabilities of the specified number of most @@ -167,7 +199,7 @@ class SamplingParams( follows the OpenAI API: The API will always return the log probability of the sampled token, so there may be up to `logprobs+1` elements in the response. When set to -1, return all `vocab_size` log probabilities.""" - prompt_logprobs: Optional[int] = None + prompt_logprobs: int | None = None """Number of log probabilities to return per prompt token. When set to -1, return all `vocab_size` log probabilities.""" # NOTE: This parameter is only exposed at the engine level for now. @@ -179,14 +211,14 @@ class SamplingParams( """Whether to skip special tokens in the output.""" spaces_between_special_tokens: bool = True """Whether to add spaces between special tokens in the output.""" - # Optional[list[LogitsProcessor]] type. We use Any here because - # Optional[list[LogitsProcessor]] type is not supported by msgspec. - logits_processors: Optional[Any] = None + # `list[LogitsProcessor] | None` type. We use Any here because + # `list[LogitsProcessor] | None` type is not supported by msgspec. + logits_processors: Any | None = None """Functions that modify logits based on previously generated tokens, and optionally prompt tokens as a first argument.""" include_stop_str_in_output: bool = False """Whether to include the stop strings in output text.""" - truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None + truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None """If set to -1, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). 
If set to `None`, truncation is disabled.""" @@ -198,60 +230,60 @@ class SamplingParams( _all_stop_token_ids: set[int] = msgspec.field(default_factory=set) # Fields used to construct logits processors - structured_outputs: Optional[StructuredOutputsParams] = None + structured_outputs: StructuredOutputsParams | None = None """Parameters for configuring structured outputs.""" - guided_decoding: Optional[GuidedDecodingParams] = None + guided_decoding: GuidedDecodingParams | None = None """Deprecated alias for structured_outputs.""" - logit_bias: Optional[dict[int, float]] = None + logit_bias: dict[int, float] | None = None """If provided, the engine will construct a logits processor that applies these logit biases.""" - allowed_token_ids: Optional[list[int]] = None + allowed_token_ids: list[int] | None = None """If provided, the engine will construct a logits processor which only retains scores for the given token ids.""" - extra_args: Optional[dict[str, Any]] = None + extra_args: dict[str, Any] | None = None """Arbitrary additional args, that can be used by custom sampling implementations, plugins, etc. Not used by any in-tree sampling implementations.""" # Fields used for bad words - bad_words: Optional[list[str]] = None + bad_words: list[str] | None = None """Words that are not allowed to be generated. More precisely, only the last token of a corresponding token sequence is not allowed when the next generated token can complete the sequence.""" - _bad_words_token_ids: Optional[list[list[int]]] = None + _bad_words_token_ids: list[list[int]] | None = None @staticmethod def from_optional( - n: Optional[int] = 1, - best_of: Optional[int] = None, - presence_penalty: Optional[float] = 0.0, - frequency_penalty: Optional[float] = 0.0, - repetition_penalty: Optional[float] = 1.0, - temperature: Optional[float] = 1.0, - top_p: Optional[float] = 1.0, + n: int | None = 1, + best_of: int | None = None, + presence_penalty: float | None = 0.0, + frequency_penalty: float | None = 0.0, + repetition_penalty: float | None = 1.0, + temperature: float | None = 1.0, + top_p: float | None = 1.0, top_k: int = 0, min_p: float = 0.0, - seed: Optional[int] = None, - stop: Optional[Union[str, list[str]]] = None, - stop_token_ids: Optional[list[int]] = None, - bad_words: Optional[list[str]] = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stop_token_ids: list[int] | None = None, + bad_words: list[str] | None = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, - max_tokens: Optional[int] = 16, + max_tokens: int | None = 16, min_tokens: int = 0, - logprobs: Optional[int] = None, - prompt_logprobs: Optional[int] = None, + logprobs: int | None = None, + prompt_logprobs: int | None = None, detokenize: bool = True, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, - logits_processors: Optional[list[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None, + logits_processors: list[LogitsProcessor] | None = None, + truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, - structured_outputs: Optional[StructuredOutputsParams] = None, - guided_decoding: Optional[GuidedDecodingParams] = None, - logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None, - allowed_token_ids: Optional[list[int]] = None, - extra_args: Optional[dict[str, Any]] = None, + structured_outputs: StructuredOutputsParams | None = None, + 
guided_decoding: GuidedDecodingParams | None = None, + logit_bias: dict[int, float] | dict[str, float] | None = None, + allowed_token_ids: list[int] | None = None, + extra_args: dict[str, Any] | None = None, ) -> "SamplingParams": if logit_bias is not None: # Convert token_id to integer @@ -306,10 +338,10 @@ def from_optional( ) def __post_init__(self) -> None: - # how we deal with `best_of``: - # if `best_of`` is not set, we default to `n`; - # if `best_of`` is set, we set `n`` to `best_of`, - # and set `_real_n`` to the original `n`. + # how we deal with `best_of`: + # if `best_of` is not set, we default to `n`; + # if `best_of` is set, we set `n` to `best_of`, + # and set `_real_n` to the original `n`. # when we return the result, we will check # if we need to return `n` or `_real_n` results if self.best_of: @@ -483,7 +515,7 @@ def _verify_greedy_sampling(self) -> None: def update_from_generation_config( self, generation_config: dict[str, Any], - model_eos_token_id: Optional[int] = None, + model_eos_token_id: int | None = None, ) -> None: """Update if there are non-default values from generation_config""" @@ -559,7 +591,7 @@ def all_stop_token_ids(self) -> set[int]: return self._all_stop_token_ids @property - def bad_words_token_ids(self) -> Optional[list[list[int]]]: + def bad_words_token_ids(self) -> list[list[int]] | None: # For internal use only. Backward compatibility not guaranteed return self._bad_words_token_ids diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index fd25d198bf1a..05760f3f8299 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -5,7 +5,6 @@ import struct from dataclasses import dataclass from enum import Enum -from typing import Optional, Union _SCALAR_TYPES_ID_MAP = {} @@ -105,7 +104,7 @@ def _floating_point_max(self) -> float: double_raw = self._floating_point_max_int() return struct.unpack("!d", struct.pack("!Q", double_raw))[0] - def _raw_max(self) -> Union[int, float]: + def _raw_max(self) -> int | float: if self.is_floating_point(): return self._floating_point_max() else: @@ -114,7 +113,7 @@ def _raw_max(self) -> Union[int, float]: ) return (1 << self.mantissa) - 1 - def _raw_min(self) -> Union[int, float]: + def _raw_min(self) -> int | float: if self.is_floating_point(): assert self.is_signed(), ( "We currently assume all floating point types are signed" @@ -168,14 +167,14 @@ def or_and_advance(member, bit_width): def size_bits(self) -> int: return self.exponent + self.mantissa + int(self.signed) - def min(self) -> Union[int, float]: + def min(self) -> int | float: """ Min representable value for this scalar type. (accounting for bias if there is one) """ return self._raw_min() - self.bias - def max(self) -> Union[int, float]: + def max(self) -> int | float: """ Max representable value for this scalar type. (accounting for bias if there is one) @@ -265,14 +264,14 @@ def __len__(self) -> int: # @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": + def int_(cls, size_bits: int, bias: int | None) -> "ScalarType": "Create a signed integer scalar type (size_bits includes sign-bit)." 
ret = cls(0, size_bits - 1, True, bias if bias else 0) ret.id # noqa B018: make sure the id is cached return ret @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": + def uint(cls, size_bits: int, bias: int | None) -> "ScalarType": """Create an unsigned integer scalar type.""" ret = cls(0, size_bits, False, bias if bias else 0) ret.id # noqa B018: make sure the id is cached diff --git a/vllm/sequence.py b/vllm/sequence.py index 7682b7f58305..6bcc94ad5c62 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -3,9 +3,8 @@ """Sequence and its related classes.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any -import msgspec import torch if TYPE_CHECKING: @@ -39,13 +38,13 @@ class RequestMetrics: arrival_time: float last_token_time: float - first_scheduled_time: Optional[float] - first_token_time: Optional[float] - time_in_queue: Optional[float] - finished_time: Optional[float] = None - scheduler_time: Optional[float] = None - model_forward_time: Optional[float] = None - model_execute_time: Optional[float] = None + first_scheduled_time: float | None + first_token_time: float | None + time_in_queue: float | None + finished_time: float | None = None + scheduler_time: float | None = None + model_forward_time: float | None = None + model_execute_time: float | None = None # cannot use msgspec.Struct here because Dynamo does not support it @@ -59,7 +58,7 @@ class IntermediateTensors: """ tensors: dict[str, torch.Tensor] - kv_connector_output: Optional[KVConnectorOutput] + kv_connector_output: KVConnectorOutput | None def __init__(self, tensors): # manually define this function, so that @@ -68,7 +67,7 @@ def __init__(self, tensors): # a string, and we will lose the information about the source file. self.tensors = tensors - def __getitem__(self, key: Union[str, slice]): + def __getitem__(self, key: str | slice): if isinstance(key, str): return self.tensors[key] elif isinstance(key, slice): @@ -92,12 +91,3 @@ def __eq__(self, other: object): def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" - - -class ExecuteModelRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, -): # type: ignore[call-arg] - # Placeholder. Remove. 
- pass diff --git a/vllm/tasks.py b/vllm/tasks.py index 85c5c6e43620..b02cde74c12a 100644 --- a/vllm/tasks.py +++ b/vllm/tasks.py @@ -5,7 +5,9 @@ GenerationTask = Literal["generate", "transcription"] GENERATION_TASKS = get_args(GenerationTask) -PoolingTask = Literal["encode", "embed", "classify", "score"] +PoolingTask = Literal[ + "embed", "classify", "score", "token_embed", "token_classify", "plugin" +] POOLING_TASKS = get_args(PoolingTask) SupportedTask = Literal[GenerationTask, PoolingTask] diff --git a/vllm/tracing.py b/vllm/tracing.py index c9b595999fc7..01bbebf35cfc 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -3,17 +3,16 @@ import os from collections.abc import Mapping -from typing import Optional from vllm.logger import init_logger -from vllm.utils import run_once +from vllm.utils.func_utils import run_once TRACE_HEADERS = ["traceparent", "tracestate"] logger = init_logger(__name__) _is_otel_imported = False -otel_import_error_traceback: Optional[str] = None +otel_import_error_traceback: str | None = None try: from opentelemetry.context.context import Context from opentelemetry.sdk.environment_variables import ( @@ -55,7 +54,7 @@ def is_otel_available() -> bool: def init_tracer( instrumenting_module_name: str, otlp_traces_endpoint: str -) -> Optional[Tracer]: +) -> Tracer | None: if not is_otel_available(): raise ValueError( "OpenTelemetry is not available. Unable to initialize " @@ -88,7 +87,7 @@ def get_span_exporter(endpoint): return OTLPSpanExporter(endpoint=endpoint) -def extract_trace_context(headers: Optional[Mapping[str, str]]) -> Optional[Context]: +def extract_trace_context(headers: Mapping[str, str] | None) -> Context | None: if is_otel_available(): headers = headers or {} return TraceContextTextMapPropagator().extract(headers) diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index b8d0cd8d2f20..3bdbe1d0a67b 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable from pathlib import Path -from typing import Callable, Optional, Union +from typing import TypeAlias from vllm.logger import init_logger @@ -9,17 +10,17 @@ CHAT_TEMPLATES_DIR = Path(__file__).parent -ChatTemplatePath = Union[Path, Callable[[str], Optional[Path]]] +ChatTemplatePath: TypeAlias = Path | Callable[[str], Path | None] -def _get_qwen_chat_template_fallback(tokenizer_name_or_path: str) -> Optional[Path]: +def _get_qwen_chat_template_fallback(tokenizer_name_or_path: str) -> Path | None: if tokenizer_name_or_path.endswith("-Chat"): return CHAT_TEMPLATES_DIR / "template_chatml.jinja" return CHAT_TEMPLATES_DIR / "template_basic.jinja" -def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Optional[Path]: +def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Path | None: # MiniCPM-V-4.5 version uses a dedicated template if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path: return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja" @@ -30,13 +31,15 @@ def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Optiona _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", - "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "chameleon": CHAT_TEMPLATES_DIR / 
"template_basic.jinja", + "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, + "siglip": CHAT_TEMPLATES_DIR / "template_basic.jinja", } @@ -58,7 +61,7 @@ def register_chat_template_fallback_path( def get_chat_template_fallback_path( model_type: str, tokenizer_name_or_path: str, -) -> Optional[Path]: +) -> Path | None: chat_template = _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK.get(model_type) if callable(chat_template): chat_template = chat_template(tokenizer_name_or_path) diff --git a/vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja b/vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja new file mode 100644 index 000000000000..287abe358642 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja @@ -0,0 +1,14 @@ +{%- if messages[0]['role'] == 'system' -%} + {%- set system_message = messages[0]['content'] -%} + {%- set messages = messages[1:] -%} +{%- else -%} + {% set system_message = '' -%} +{%- endif -%} + +{{ bos_token + system_message }} +{%- for message in messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif -%} + {{ message['content'] }} +{%- endfor -%} diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 87bbe73d834a..7802cece6075 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -4,10 +4,11 @@ import json import os import time +from collections.abc import Callable from dataclasses import asdict from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Literal, Optional, TypeVar, Union +from typing import Any, Literal, TypeVar import huggingface_hub from huggingface_hub import ( @@ -47,7 +48,7 @@ logger = init_logger(__name__) -def _get_hf_token() -> Optional[str]: +def _get_hf_token() -> str | None: """ Get the HuggingFace token from environment variable. 
@@ -74,6 +75,7 @@ def __getitem__(self, key): deepseek_vl_v2="DeepseekVLV2Config", deepseek_v3="DeepseekV3Config", deepseek_v32="DeepseekV3Config", + flex_olmo="FlexOlmoConfig", kimi_vl="KimiVLConfig", Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) @@ -107,10 +109,10 @@ def __getitem__(self, key): class HFConfigParser(ConfigParserBase): def parse( self, - model: Union[str, Path], + model: str | Path, trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, + revision: str | None = None, + code_revision: str | None = None, **kwargs, ) -> tuple[dict, PretrainedConfig]: kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE @@ -172,10 +174,10 @@ def parse( class MistralConfigParser(ConfigParserBase): def parse( self, - model: Union[str, Path], + model: str | Path, trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, + revision: str | None = None, + code_revision: str | None = None, **kwargs, ) -> tuple[dict, PretrainedConfig]: # This function loads a params.json config which @@ -246,8 +248,8 @@ def register_config_parser(config_format: str): ... self, ... model: Union[str, Path], ... trust_remote_code: bool, - ... revision: Optional[str] = None, - ... code_revision: Optional[str] = None, + ... revision: str | None = None, + ... code_revision: str | None = None, ... **kwargs, ... ) -> tuple[dict, PretrainedConfig]: ... raise NotImplementedError @@ -309,9 +311,9 @@ def with_retry( def list_repo_files( repo_id: str, *, - revision: Optional[str] = None, - repo_type: Optional[str] = None, - token: Union[str, bool, None] = None, + revision: str | None = None, + repo_type: str | None = None, + token: str | bool | None = None, ) -> list[str]: def lookup_files() -> list[str]: # directly list files if model is local @@ -347,9 +349,9 @@ def file_exists( repo_id: str, file_name: str, *, - repo_type: Optional[str] = None, - revision: Optional[str] = None, - token: Union[str, bool, None] = None, + repo_type: str | None = None, + revision: str | None = None, + token: str | bool | None = None, ) -> bool: file_list = list_repo_files( repo_id, repo_type=repo_type, revision=revision, token=token @@ -359,7 +361,7 @@ def file_exists( # In offline mode the result can be a false negative def file_or_path_exists( - model: Union[str, Path], config_name: str, revision: Optional[str] + model: str | Path, config_name: str, revision: str | None ) -> bool: if (local_path := Path(model)).exists(): return (local_path / config_name).is_file() @@ -492,10 +494,10 @@ def maybe_override_with_speculators( model: str, tokenizer: str, trust_remote_code: bool, - revision: Optional[str] = None, - vllm_speculative_config: Optional[dict[str, Any]] = None, + revision: str | None = None, + vllm_speculative_config: dict[str, Any] | None = None, **kwargs, -) -> tuple[str, str, Optional[dict[str, Any]]]: +) -> tuple[str, str, dict[str, Any] | None]: """ Resolve model configuration when speculators are detected. 
@@ -550,13 +552,13 @@ def maybe_override_with_speculators( def get_config( - model: Union[str, Path], + model: str | Path, trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - config_format: Union[str, ConfigFormat] = "auto", - hf_overrides_kw: Optional[dict[str, Any]] = None, - hf_overrides_fn: Optional[Callable[[PretrainedConfig], PretrainedConfig]] = None, + revision: str | None = None, + code_revision: str | None = None, + config_format: str | ConfigFormat = "auto", + hf_overrides_kw: dict[str, Any] | None = None, + hf_overrides_fn: Callable[[PretrainedConfig], PretrainedConfig] | None = None, **kwargs, ) -> PretrainedConfig: # Separate model folder from file path for GGUF models @@ -668,8 +670,8 @@ def get_config( def try_get_local_file( - model: Union[str, Path], file_name: str, revision: Optional[str] = "main" -) -> Optional[Path]: + model: str | Path, file_name: str, revision: str | None = "main" +) -> Path | None: file_path = Path(model) / file_name if file_path.is_file(): return file_path @@ -686,7 +688,7 @@ def try_get_local_file( def get_hf_file_to_dict( - file_name: str, model: Union[str, Path], revision: Optional[str] = "main" + file_name: str, model: str | Path, revision: str | None = "main" ): """ Downloads a file from the Hugging Face Hub and returns @@ -734,7 +736,7 @@ def get_hf_file_to_dict( @cache -def get_pooling_config(model: str, revision: Optional[str] = "main") -> Optional[dict]: +def get_pooling_config(model: str, revision: str | None = "main") -> dict | None: """ This function gets the pooling and normalize config from the model - only applies to @@ -798,7 +800,7 @@ def get_pooling_config(model: str, revision: Optional[str] = "main") -> Optional return None -def get_pooling_config_name(pooling_name: str) -> Union[str, None]: +def get_pooling_config_name(pooling_name: str) -> str | None: if "pooling_mode_" in pooling_name: pooling_name = pooling_name.replace("pooling_mode_", "") @@ -819,7 +821,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]: @cache def get_sentence_transformer_tokenizer_config( - model: Union[str, Path], revision: Optional[str] = "main" + model: str | Path, revision: str | None = "main" ): """ Returns the tokenization configuration dictionary for a @@ -941,7 +943,7 @@ def _reduce_config(config: VllmConfig): cloudpickle.register_pickle_by_value(transformers_modules) # ray vendors its own version of cloudpickle - from vllm.executor.ray_utils import ray + from vllm.v1.executor.ray_utils import ray if ray: ray.cloudpickle.register_pickle_by_value(transformers_modules) @@ -957,9 +959,9 @@ def _reduce_config(config: VllmConfig): def get_hf_image_processor_config( - model: Union[str, Path], - hf_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, + model: str | Path, + hf_token: bool | str | None = None, + revision: str | None = None, **kwargs, ) -> dict[str, Any]: # ModelScope does not provide an interface for image_processor @@ -991,9 +993,9 @@ def get_hf_text_config(config: PretrainedConfig): def try_get_generation_config( model: str, trust_remote_code: bool, - revision: Optional[str] = None, - config_format: Union[str, ConfigFormat] = "auto", -) -> Optional[GenerationConfig]: + revision: str | None = None, + config_format: str | ConfigFormat = "auto", +) -> GenerationConfig | None: try: return GenerationConfig.from_pretrained( model, @@ -1015,7 +1017,7 @@ def try_get_generation_config( def try_get_safetensors_metadata( model: str, *, - revision: 
Optional[str] = None, + revision: str | None = None, ): get_safetensors_metadata_partial = partial( get_safetensors_metadata, @@ -1033,10 +1035,10 @@ def try_get_safetensors_metadata( def try_get_tokenizer_config( - pretrained_model_name_or_path: Union[str, os.PathLike], + pretrained_model_name_or_path: str | os.PathLike, trust_remote_code: bool, - revision: Optional[str] = None, -) -> Optional[dict[str, Any]]: + revision: str | None = None, +) -> dict[str, Any] | None: try: return get_tokenizer_config( pretrained_model_name_or_path, @@ -1047,10 +1049,44 @@ def try_get_tokenizer_config( return None +@cache +def try_get_dense_modules( + model: str | Path, + revision: str | None = None, +) -> list[dict[str, Any]] | None: + try: + modules = get_hf_file_to_dict("modules.json", model, revision) + if not modules: + return None + + if isinstance(modules, dict): + modules = modules.get("modules", []) + + dense_modules = [ + m for m in modules if m.get("type") == "sentence_transformers.models.Dense" + ] + if not dense_modules: + return None + + layer_configs = [] + for module in dense_modules: + folder = module.get("path", "") + + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = get_hf_file_to_dict(config_path, model, revision) + if not layer_config: + continue + layer_config["folder"] = folder + layer_configs.append(layer_config) + return layer_configs + except Exception: + return None + + def get_safetensors_params_metadata( model: str, *, - revision: Optional[str] = None, + revision: str | None = None, ) -> dict[str, Any]: """ Get the safetensors metadata for remote model repository. @@ -1111,7 +1147,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: return max_position_embeddings -def get_model_path(model: Union[str, Path], revision: Optional[str] = None): +def get_model_path(model: str | Path, revision: str | None = None): if os.path.exists(model): return model assert huggingface_hub.constants.HF_HUB_OFFLINE @@ -1131,8 +1167,8 @@ def get_model_path(model: Union[str, Path], revision: Optional[str] = None): def get_hf_file_bytes( - file_name: str, model: Union[str, Path], revision: Optional[str] = "main" -) -> Optional[bytes]: + file_name: str, model: str | Path, revision: str | None = "main" +) -> bytes | None: """Get file contents from HuggingFace repository as bytes.""" file_path = try_get_local_file(model=model, file_name=file_name, revision=revision) diff --git a/vllm/transformers_utils/config_parser_base.py b/vllm/transformers_utils/config_parser_base.py index 0e1c49b428b0..79d47ff56042 100644 --- a/vllm/transformers_utils/config_parser_base.py +++ b/vllm/transformers_utils/config_parser_base.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Union from transformers import PretrainedConfig @@ -12,10 +11,10 @@ class ConfigParserBase(ABC): @abstractmethod def parse( self, - model: Union[str, Path], + model: str | Path, trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, + revision: str | None = None, + code_revision: str | None = None, **kwargs, ) -> tuple[dict, PretrainedConfig]: raise NotImplementedError diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 6917123ce662..befe9cdae76a 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -17,6 +17,7 @@ # tiiuae/falcon-7b(-instruct) models. 
Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.flex_olmo import FlexOlmoConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig from vllm.transformers_utils.configs.lfm2_moe import Lfm2MoeConfig @@ -45,6 +46,7 @@ "DeepseekV3Config", "DotsOCRConfig", "EAGLEConfig", + "FlexOlmoConfig", "RWConfig", "JAISConfig", "Lfm2MoeConfig", diff --git a/vllm/transformers_utils/configs/dotsocr.py b/vllm/transformers_utils/configs/dotsocr.py index 446693b9a32e..1e42cb2fd859 100644 --- a/vllm/transformers_utils/configs/dotsocr.py +++ b/vllm/transformers_utils/configs/dotsocr.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any from transformers.configuration_utils import PretrainedConfig from transformers.models.qwen2 import Qwen2Config @@ -57,7 +57,7 @@ def __init__( self, image_token_id=151665, video_token_id=151656, - vision_config: Optional[dict] = None, + vision_config: dict | None = None, *args, **kwargs, ): diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 6e18513d1234..4da877f9e81f 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional, Union from transformers import AutoConfig, PretrainedConfig @@ -14,12 +13,12 @@ class EAGLEConfig(PretrainedConfig): def __init__( self, - model: Union[PretrainedConfig, dict, None] = None, - truncated_vocab_size: Optional[int] = None, - method: Optional[str] = "eagle", + model: PretrainedConfig | dict | None = None, + truncated_vocab_size: int | None = None, + method: str | None = "eagle", **kwargs, ): - model_config: Union[PretrainedConfig, DeepseekV2Config, None] + model_config: PretrainedConfig | DeepseekV2Config | None if isinstance(model, dict): archs = model.get("architectures", []) target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"] @@ -84,7 +83,7 @@ def __init__( @classmethod def from_pretrained( cls, - pretrained_model_name_or_path: Union[str, os.PathLike], + pretrained_model_name_or_path: str | os.PathLike, **kwargs, ) -> "EAGLEConfig": config_dict, kwargs = cls.get_config_dict( diff --git a/vllm/transformers_utils/configs/flex_olmo.py b/vllm/transformers_utils/configs/flex_olmo.py new file mode 100644 index 000000000000..1f2f4d446288 --- /dev/null +++ b/vllm/transformers_utils/configs/flex_olmo.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from transformers.configuration_utils import PretrainedConfig + + +class FlexOlmoConfig(PretrainedConfig): + model_type = "flex_olmo" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=100352, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + pad_token_id=100277, + bos_token_id=None, + eos_token_id=100257, + tie_word_embeddings=False, + rope_theta=500000.0, + rope_scaling=None, + attention_bias=False, + 
attention_dropout=0.0, + num_experts_per_tok=5, + num_experts=7, + output_router_logits=False, + router_aux_loss_coef=0.01, + norm_topk_prob=False, + **kwargs, + ): + if "architectures" not in kwargs: + kwargs["architectures"] = ["FlexOlmoForCausalLM"] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.norm_topk_prob = norm_topk_prob + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] diff --git a/vllm/transformers_utils/configs/kimi_vl.py b/vllm/transformers_utils/configs/kimi_vl.py index 89a8878465b6..e8c19d0ec2ff 100644 --- a/vllm/transformers_utils/configs/kimi_vl.py +++ b/vllm/transformers_utils/configs/kimi_vl.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py -from typing import Optional, Union from transformers.configuration_utils import PretrainedConfig @@ -14,8 +13,8 @@ class KimiVLConfig(PretrainedConfig): def __init__( self, - vision_config: Optional[Union[dict, MoonViTConfig]] = None, - text_config: Optional[Union[dict, DeepseekV2Config]] = None, + vision_config: dict | MoonViTConfig | None = None, + text_config: dict | DeepseekV2Config | None = None, ignore_index: int = -100, media_placeholder_token_id: int = 163605, pad_token_id: int = 0, diff --git a/vllm/transformers_utils/configs/lfm2_moe.py b/vllm/transformers_utils/configs/lfm2_moe.py index 7d17c2b4f74c..37c038e12db8 100644 --- a/vllm/transformers_utils/configs/lfm2_moe.py +++ b/vllm/transformers_utils/configs/lfm2_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from transformers.configuration_utils import PretrainedConfig @@ -115,7 +114,7 @@ def __init__( use_expert_bias: bool = True, routed_scaling_factor: float = 1.0, norm_topk_prob: bool = True, - layer_types: Optional[list[str]] = None, + layer_types: list[str] | None = None, **kwargs, ): self.vocab_size = vocab_size diff --git a/vllm/transformers_utils/configs/medusa.py b/vllm/transformers_utils/configs/medusa.py index 7dcfd0cf26ae..bfa0f30e8961 100644 --- a/vllm/transformers_utils/configs/medusa.py +++ b/vllm/transformers_utils/configs/medusa.py @@ -2,7 +2,6 @@ # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional, Union from transformers import PretrainedConfig @@ -18,7 +17,7 @@ def __init__( num_hidden_layers: int = 1, max_paths: int = 64, topk: int = 10, - truncated_vocab_size: Optional[int] = None, + truncated_vocab_size: int | None = None, **kwargs, ): self.hidden_size = hidden_size @@ -39,7 +38,7 @@ def __init__( @classmethod def from_pretrained( cls, - pretrained_model_name_or_path: Union[str, os.PathLike], + pretrained_model_name_or_path: str | os.PathLike, **kwargs, ) -> "MedusaConfig": config_dict, kwargs = cls.get_config_dict( diff --git a/vllm/transformers_utils/configs/midashenglm.py b/vllm/transformers_utils/configs/midashenglm.py index 5c9e72be8ebf..e49bd26b2b00 100644 --- a/vllm/transformers_utils/configs/midashenglm.py +++ b/vllm/transformers_utils/configs/midashenglm.py @@ -21,7 +21,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union from transformers import PretrainedConfig from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( @@ -36,15 +35,15 @@ def __init__( self, embed_dim: int = 768, outputdim: int = 527, - patch_size: Union[int, tuple[int, int]] = 16, - patch_stride: Union[int, tuple[int, int]] = 16, + patch_size: int | tuple[int, int] = 16, + patch_stride: int | tuple[int, int] = 16, input_channels: int = 1, target_length: int = 1012, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, - init_values: Optional[float] = None, + init_values: float | None = None, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, f_min: float = 0.0, @@ -86,10 +85,10 @@ class MiDashengLMConfig(PretrainedConfig): def __init__( self, - audio_encoder_config: Optional[dict] = None, + audio_encoder_config: dict | None = None, subsample_factor: int = 5, - text_config: Optional[dict] = None, - audio_token_id: Optional[int] = None, + text_config: dict | None = None, + audio_token_id: int | None = None, **kwargs, ): self.audio_encoder_config = DashengConfig(**(audio_encoder_config or {})) diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py index 45d76a8fdf26..75745f227f48 100644 --- a/vllm/transformers_utils/configs/mlp_speculator.py +++ b/vllm/transformers_utils/configs/mlp_speculator.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from transformers import PretrainedConfig @@ -19,7 +18,7 @@ def __init__( emb_dim: int = 4096, inner_dim: int = 0, n_predict: int = 3, - top_k_tokens_per_head: Optional[list[int]] = None, + top_k_tokens_per_head: list[int] | None = None, n_candidates: int = 5, tie_weights: bool = False, scale_input: bool = False, diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py index c8b6784d6a8e..68c40002098c 100644 --- a/vllm/transformers_utils/configs/nemotron_h.py +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -185,6 +185,15 @@ def __init__( mamba_proj_bias=False, mamba_chunk_size=256, rescale_prenorm_residual=True, + n_routed_experts=8, + n_shared_experts=1, + moe_intermediate_size=7688, + moe_shared_expert_intermediate_size=7688, + num_experts_per_tok=2, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + norm_topk_prob=True, 
**kwargs, ): self.vocab_size = vocab_size @@ -241,6 +250,15 @@ def __init__( self.mamba_proj_bias = mamba_proj_bias self.chunk_size = mamba_chunk_size self.rescale_prenorm_residual = rescale_prenorm_residual + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.moe_intermediate_size = moe_intermediate_size + self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size # noqa: E501 + self.num_experts_per_tok = num_experts_per_tok + self.routed_scaling_factor = routed_scaling_factor + self.n_group = n_group + self.topk_group = topk_group + self.norm_topk_prob = norm_topk_prob super().__init__( pad_token_id=pad_token_id, @@ -258,5 +276,7 @@ def layers_block_type(self): else "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + if self.hybrid_override_pattern[i] == "-" + else "moe" for i in range(self.num_hidden_layers) ] diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index 404fa700a26c..294b4c9037aa 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -5,7 +5,7 @@ # adapted from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py # and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py # Ovis Config with AimV2 config registration removed for Transformers compatibility -from typing import Any, Optional, Union +from typing import Any from transformers import AutoConfig, PretrainedConfig @@ -76,7 +76,7 @@ def __init__( tau=1.0, depths=None, drop_cls_token=False, - backbone_config: Optional[Union[PretrainedConfig, dict]] = None, + backbone_config: PretrainedConfig | dict | None = None, hidden_stride: int = 1, **kwargs, ): @@ -142,8 +142,8 @@ class OvisConfig(PretrainedConfig): def __init__( self, - llm_config: Optional[Union[PretrainedConfig, dict]] = None, - visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None, + llm_config: PretrainedConfig | dict | None = None, + visual_tokenizer_config: PretrainedConfig | dict | None = None, multimodal_max_length=8192, hidden_size=None, conversation_formatter_class=None, diff --git a/vllm/transformers_utils/configs/radio.py b/vllm/transformers_utils/configs/radio.py index f13598034bae..2b6544fb273c 100644 --- a/vllm/transformers_utils/configs/radio.py +++ b/vllm/transformers_utils/configs/radio.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Radio vision model configuration""" -from typing import Optional, Union - from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -60,9 +58,9 @@ def __init__( initializer_factor: float = 1.0, hidden_act: str = "gelu", max_img_size: int = 2048, - norm_mean: Union[tuple[float, float, float], list] = OPENAI_CLIP_MEAN, - norm_std: Union[tuple[float, float, float], list] = OPENAI_CLIP_STD, - reg_tokens: Optional[int] = None, + norm_mean: tuple[float, float, float] | list = OPENAI_CLIP_MEAN, + norm_std: tuple[float, float, float] | list = OPENAI_CLIP_STD, + reg_tokens: int | None = None, **kwargs, ): self.model_name = model_name diff --git a/vllm/transformers_utils/configs/speculators/base.py b/vllm/transformers_utils/configs/speculators/base.py index 1c415a43360e..bf3a5d413192 100644 --- a/vllm/transformers_utils/configs/speculators/base.py +++ b/vllm/transformers_utils/configs/speculators/base.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project import os -from typing import Any, Union +from typing import Any from transformers import PretrainedConfig @@ -18,7 +18,7 @@ class SpeculatorsConfig(PretrainedConfig): @classmethod def from_pretrained( cls, - pretrained_model_name_or_path: Union[str, os.PathLike], + pretrained_model_name_or_path: str | os.PathLike, **kwargs, ) -> "SpeculatorsConfig": """Load speculators Eagle config and convert to vLLM format.""" diff --git a/vllm/transformers_utils/configs/step3_vl.py b/vllm/transformers_utils/configs/step3_vl.py index 36d39e828a93..637b82d88e26 100644 --- a/vllm/transformers_utils/configs/step3_vl.py +++ b/vllm/transformers_utils/configs/step3_vl.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional, Union +from typing import Any from transformers.configuration_utils import PretrainedConfig @@ -53,7 +53,7 @@ def __init__( moe_num_experts: int = 48, moe_top_k: int = 3, rope_theta: float = 500000, - rope_scaling: Optional[dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, max_position_embedding: int = 65536, share_expert_dim: int = 5120, share_q_dim: int = 2048, @@ -147,8 +147,8 @@ class Step3VLConfig(PretrainedConfig): def __init__( self, - vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None, - text_config: Optional[Union[dict, Step3TextConfig]] = None, + vision_config: dict | Step3VisionEncoderConfig | None = None, + text_config: dict | Step3TextConfig | None = None, understand_projector_stride: int = 1, projector_bias: bool = True, image_token_id: int = 128001, diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index ac22304e9125..fc0360a9ecb4 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py -from typing import Any, Optional +from typing import Any import transformers @@ -50,10 +50,10 @@ class UltravoxConfig(transformers.PretrainedConfig): def __init__( self, - audio_config: Optional[dict[str, Any]] = None, - text_config: Optional[dict[str, Any]] = None, - audio_model_id: Optional[str] = None, - text_model_id: Optional[str] = None, + audio_config: dict[str, Any] | None = None, + text_config: dict[str, Any] | None = None, + audio_model_id: str | None = None, + text_model_id: str | None = None, ignore_index: int = -100, audio_token_index: int = 32000, hidden_size: int = 4096, diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index 60742ae97d5d..560526bfd823 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from .tokenizer import AnyTokenizer -def _replace_none_with_empty(tokens: list[Optional[str]]): +def _replace_none_with_empty(tokens: list[str | None]): for i, token in enumerate(tokens): if token is None: tokens[i] = "" @@ -111,7 +110,7 @@ def convert_ids_list_to_tokens( def detokenize_incrementally( tokenizer: AnyTokenizer, all_input_ids: list[int], - prev_tokens: Optional[list[str]], + prev_tokens: list[str] | None, prefix_offset: int, read_offset: 
int, skip_special_tokens: bool = False, diff --git a/vllm/transformers_utils/dynamic_module.py b/vllm/transformers_utils/dynamic_module.py index 3c273ad41da0..24ead83785f7 100644 --- a/vllm/transformers_utils/dynamic_module.py +++ b/vllm/transformers_utils/dynamic_module.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional, Union from transformers.dynamic_module_utils import get_class_from_dynamic_module @@ -14,18 +13,18 @@ def try_get_class_from_dynamic_module( class_reference: str, pretrained_model_name_or_path: str, - cache_dir: Optional[Union[str, os.PathLike]] = None, + cache_dir: str | os.PathLike | None = None, force_download: bool = False, - resume_download: Optional[bool] = None, - proxies: Optional[dict[str, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, + resume_download: bool | None = None, + proxies: dict[str, str] | None = None, + token: bool | str | None = None, + revision: str | None = None, local_files_only: bool = False, - repo_type: Optional[str] = None, - code_revision: Optional[str] = None, + repo_type: str | None = None, + code_revision: str | None = None, warn_on_fail: bool = True, **kwargs, -) -> Optional[type]: +) -> type | None: """ As `transformers.dynamic_module_utils.get_class_from_dynamic_module`, but ignoring any errors. diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 81f9b76b5ef7..98eb9cf33595 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from functools import lru_cache -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast from transformers import ( AutoFeatureExtractor, @@ -16,7 +16,7 @@ from transformers.video_processing_utils import BaseVideoProcessor from typing_extensions import TypeVar -from vllm.utils import get_allowed_kwarg_only_overrides +from vllm.utils.func_utils import get_allowed_kwarg_only_overrides if TYPE_CHECKING: from vllm.config import ModelConfig @@ -45,7 +45,7 @@ def __hash__(self) -> int: # type: ignore[override] return hash(tuple(self)) -def _get_processor_factory_fn(processor_cls: Union[type, tuple[type, ...]]): +def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]): if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin: return AutoProcessor.from_pretrained if hasattr(processor_cls, "from_pretrained"): @@ -56,7 +56,7 @@ def _get_processor_factory_fn(processor_cls: Union[type, tuple[type, ...]]): def _merge_mm_kwargs( model_config: "ModelConfig", - processor_cls: Union[type, tuple[type, ...]], + processor_cls: type | tuple[type, ...], /, **kwargs, ): @@ -86,9 +86,9 @@ def _merge_mm_kwargs( def get_processor( processor_name: str, *args: Any, - revision: Optional[str] = None, + revision: str | None = None, trust_remote_code: bool = False, - processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin, **kwargs: Any, ) -> _P: """Load a processor for the given model name via HuggingFace.""" @@ -146,7 +146,7 @@ def get_processor( def cached_processor_from_config( model_config: "ModelConfig", - processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + processor_cls: type[_P] | tuple[type[_P], ...] 
= ProcessorMixin, **kwargs: Any, ) -> _P: return cached_get_processor( @@ -161,7 +161,7 @@ def cached_processor_from_config( def get_feature_extractor( processor_name: str, *args: Any, - revision: Optional[str] = None, + revision: str | None = None, trust_remote_code: bool = False, **kwargs: Any, ): @@ -211,7 +211,7 @@ def cached_feature_extractor_from_config( def get_image_processor( processor_name: str, *args: Any, - revision: Optional[str] = None, + revision: str | None = None, trust_remote_code: bool = False, **kwargs: Any, ): @@ -261,9 +261,9 @@ def cached_image_processor_from_config( def get_video_processor( processor_name: str, *args: Any, - revision: Optional[str] = None, + revision: str | None = None, trust_remote_code: bool = False, - processor_cls_overrides: Optional[type[_V]] = None, + processor_cls_overrides: type[_V] | None = None, **kwargs: Any, ): """Load a video processor for the given model name via HuggingFace.""" @@ -300,7 +300,7 @@ def get_video_processor( def cached_video_processor_from_config( model_config: "ModelConfig", - processor_cls: Optional[type[_V]] = None, + processor_cls: type[_V] | None = None, **kwargs: Any, ): return cached_get_video_processor( diff --git a/vllm/transformers_utils/processors/deepseek_ocr.py b/vllm/transformers_utils/processors/deepseek_ocr.py new file mode 100644 index 000000000000..bb7aa0c17486 --- /dev/null +++ b/vllm/transformers_utils/processors/deepseek_ocr.py @@ -0,0 +1,438 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# adapted from https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/image_process.py +import math + +import torch +import torchvision.transforms as T +from PIL import Image, ImageOps +from transformers import AutoProcessor, BatchFeature, LlamaTokenizerFast +from transformers.processing_utils import ProcessorMixin + +# TODO(Isotr0py): change modes for variants +# see: https://github.com/deepseek-ai/DeepSeek-OCR/blob/8cf003d38821fa1b19c73da3bd1b0dc262ea8136/DeepSeek-OCR-master/DeepSeek-OCR-vllm/config.py#L1-L6 +# Tiny: base_size = 512, image_size = 512, crop_mode = False +# Small: base_size = 640, image_size = 640, crop_mode = False +# Base: base_size = 1024, image_size = 1024, crop_mode = False +# Large: base_size = 1280, image_size = 1280, crop_mode = False +# Gundam: base_size = 1024, image_size = 640, crop_mode = True +BASE_SIZE = 1024 +IMAGE_SIZE = 640 +CROP_MODE = True + +# TODO(Isotr0py): Expose as mm_kwargs +MIN_CROPS = 2 +MAX_CROPS = 6 # max:9; If your GPU memory is small, it is recommended to set it to 6. 
+ + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def calculate_aspect_ratios( + min_num: int = MIN_CROPS, max_num: int = MAX_CROPS +) -> list[tuple[int, int]]: + target_ratios: set[tuple[int, int]] = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + sorted_target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + return sorted_target_ratios + + +def count_tiles( + orig_width, + orig_height, + min_num=MIN_CROPS, + max_num=MAX_CROPS, + image_size=640, + use_thumbnail=False, +): + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = calculate_aspect_ratios(min_num, max_num) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + return target_aspect_ratio + + +def dynamic_preprocess( + image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False +): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = calculate_aspect_ratios(min_num, max_num) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images, target_aspect_ratio + + +class ImageTransform: + def __init__( + self, + mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + std: tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + ): + self.mean = mean + self.std = std + self.normalize = normalize + + transform_pipelines = [T.ToTensor()] + + if normalize: + transform_pipelines.append(T.Normalize(mean, std)) + + self.transform = T.Compose(transform_pipelines) + + def __call__(self, pil_img: Image.Image): + x = self.transform(pil_img) + return x + + +class DeepseekOCRProcessor(ProcessorMixin): + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + attributes = ["tokenizer"] + + def __init__( + self, + tokenizer: LlamaTokenizerFast, + patch_size: int = 16, + 
downsample_ratio: int = 4, + image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + image_std: tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + image_token: str = "<image>", + pad_token: str = "<|▁pad▁|>", + add_special_token: bool = False, + sft_format: str = "deepseek", + mask_prompt: bool = True, + ignore_id: int = -100, + **kwargs, + ): + self.image_size = IMAGE_SIZE + self.base_size = BASE_SIZE + self.patch_size = 16 + self.image_mean = image_mean + self.image_std = image_std + self.normalize = normalize + self.downsample_ratio = 4 + + self.image_transform = ImageTransform( + mean=image_mean, std=image_std, normalize=normalize + ) + + self.tokenizer = tokenizer + self.tokenizer.padding_side = "left" # must set this,padding side with make a difference in batch inference # noqa: E501 + + # add the pad_token as special token to use 'tokenizer.pad_token' + # and 'tokenizer.pad_token_id' + if self.tokenizer.pad_token is None: + self.tokenizer.add_special_tokens({"pad_token": pad_token}) + + # add image token + self.image_token_id = self.tokenizer.vocab.get(image_token) + self.image_token = image_token + self.pad_token = pad_token + self.add_special_token = add_special_token + self.sft_format = sft_format + self.mask_prompt = mask_prompt + self.ignore_id = ignore_id + + super().__init__( + tokenizer, + **kwargs, + ) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def pad_id(self): + return self.tokenizer.pad_token_id + + def encode(self, text: str, bos: bool = True, eos: bool = False): + t = self.tokenizer.encode(text, add_special_tokens=False) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: list[int], **kwargs) -> str: + return self.tokenizer.decode(t, **kwargs) + + def process_one( + self, + prompt: str, + images: list[Image.Image], + crop_mode: bool = CROP_MODE, + ): + """ + + Args: + prompt (str): the formatted prompt; + images (List[ImageType]): the list of images; + crop_mode (bool): if True, then crop the image; + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - pixel_values (torch.FloatTensor): [n_patches, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + assert prompt is not None and images is not None, ( + "prompt and images must be used at the same time." 
+ ) + + sft_format = prompt + + ( + input_ids, + pixel_values, + images_crop, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + _, + ) = self.tokenize_with_images( + conversation=sft_format, + images=images, + bos=True, + eos=True, + cropping=crop_mode, + ) + + prepare = BatchFeature( + data=dict( + input_ids=input_ids, + pixel_values=pixel_values, + images_crop=images_crop, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + num_image_tokens=num_image_tokens, + ), + tensor_type="pt", + ) + return prepare + + def __call__( + self, + *, + prompt: str, + images: list[Image.Image], + crop_mode: bool = CROP_MODE, + **kwargs, + ): + prepare = self.process_one( + prompt=prompt, + images=images, + crop_mode=crop_mode, + ) + + return prepare + + def tokenize_with_images( + self, + conversation: str, + images: list[Image.Image], + bos: bool = True, + eos: bool = True, + cropping: bool = True, + ): + """Tokenize text with <image> tags.""" + + assert conversation.count(self.image_token) == len(images) + text_splits = conversation.split(self.image_token) + images_list, images_crop_list, images_seq_mask, images_spatial_crop = ( + [], + [], + [], + [], + ) + image_shapes = [] + num_image_tokens = [] + tokenized_str = [] + for text_sep, image in zip(text_splits, images): + tokenized_sep = self.encode(text_sep, bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + image_shapes.append(image.size) + + images_crop_raw = [] + if image.size[0] <= 640 and image.size[1] <= 640: + crop_ratio = [1, 1] + elif cropping: + images_crop_raw, crop_ratio = dynamic_preprocess( + image, image_size=IMAGE_SIZE + ) + else: + crop_ratio = [1, 1] + + if self.image_size <= 640 and not cropping: + image = image.resize((self.image_size, self.image_size)) + + global_view = ImageOps.pad( + image, + (self.base_size, self.base_size), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) + images_list.append(self.image_transform(global_view)) + + num_width_tiles, num_height_tiles = crop_ratio + images_spatial_crop.append([num_width_tiles, num_height_tiles]) + + if num_width_tiles > 1 or num_height_tiles > 1: + for cropped_image in images_crop_raw: + images_crop_list.append(self.image_transform(cropped_image)) + + num_queries = math.ceil( + (self.image_size // self.patch_size) / self.downsample_ratio + ) + num_queries_base = math.ceil( + (self.base_size // self.patch_size) / self.downsample_ratio + ) + + tokenized_image = ( + [self.image_token_id] * num_queries_base + [self.image_token_id] + ) * num_queries_base + tokenized_image += [self.image_token_id] + if num_width_tiles > 1 or num_height_tiles > 1: + local_row = [self.image_token_id] * (num_queries * num_width_tiles + 1) + tokenized_image += local_row * (num_queries * num_height_tiles) + tokenized_str += tokenized_image + images_seq_mask += [True] * len(tokenized_image) + num_image_tokens.append(len(tokenized_image)) + + """process the last text split""" + tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) + tokenized_str += tokenized_sep + images_seq_mask += [False] * len(tokenized_sep) + + """add the bos and eos tokens""" + if bos: + tokenized_str = [self.bos_id] + tokenized_str + images_seq_mask = [False] + images_seq_mask + if eos: + tokenized_str = tokenized_str + [self.eos_id] + images_seq_mask = images_seq_mask + [False] + + assert len(tokenized_str) == len(images_seq_mask), ( + f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} 
" + f"is not equal to images_seq_mask's length {len(images_seq_mask)}." + ) + + masked_tokenized_str = [] + for token_index in tokenized_str: + if token_index != self.image_token_id: + masked_tokenized_str.append(token_index) + else: + masked_tokenized_str.append(self.ignore_id) + + assert ( + len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str) + ), ( + f"tokenized_str's length {len(tokenized_str)}, " + f"input_ids' length {len(masked_tokenized_str)}, " + f"images_seq_mask's length {len(images_seq_mask)}, are not equal." + ) + + input_ids = torch.LongTensor(tokenized_str) + target_ids = torch.LongTensor(masked_tokenized_str) + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) + + # set input_ids < 0 | input_ids == self.image_token_id as ignore_id + target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = ( + self.ignore_id + ) + input_ids[input_ids < 0] = self.pad_id + + # Remove the ending eos token + assert input_ids[-1] == self.eos_id + input_ids = input_ids[:-1] + target_ids = target_ids[:-1] + images_seq_mask = images_seq_mask[:-1] + + if len(images_list) == 0: + pixel_values = torch.zeros((0, 3, self.base_size, self.base_size)) + images_spatial_crop = torch.zeros((0, 2), dtype=torch.long) + images_crop = torch.zeros((0, 3, self.image_size, self.image_size)) + else: + pixel_values = torch.stack(images_list, dim=0) + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) + if images_crop_list: + images_crop = torch.stack(images_crop_list, dim=0) + else: + images_crop = torch.zeros((0, 3, self.image_size, self.image_size)) + + input_ids = input_ids.unsqueeze(0) + + return ( + input_ids, + pixel_values, + images_crop, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + image_shapes, + ) + + +AutoProcessor.register("DeepseekOCRProcessor", DeepseekOCRProcessor) diff --git a/vllm/transformers_utils/processors/ovis.py b/vllm/transformers_utils/processors/ovis.py index 58c1b1a91658..252f83399365 100644 --- a/vllm/transformers_utils/processors/ovis.py +++ b/vllm/transformers_utils/processors/ovis.py @@ -23,7 +23,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import cached_property -from typing import Union import PIL import torch @@ -104,9 +103,10 @@ def extra_special_tokens(self): def __call__( self, images: ImageInput = None, - text: Union[ - TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] - ] = None, + text: TextInput + | PreTokenizedInput + | list[TextInput] + | list[PreTokenizedInput] = None, **kwargs: Unpack[OvisProcessorKwargs], ) -> BatchFeature: """ diff --git a/vllm/transformers_utils/processors/ovis2_5.py b/vllm/transformers_utils/processors/ovis2_5.py index fba26d1d0304..4c084fdccabc 100644 --- a/vllm/transformers_utils/processors/ovis2_5.py +++ b/vllm/transformers_utils/processors/ovis2_5.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from functools import cached_property -from typing import Optional, Union import numpy as np import PIL @@ -99,10 +98,11 @@ def extra_special_tokens(self): def __call__( self, images: ImageInput = None, - videos: Union[np.ndarray, list[ImageInput]] = None, - text: Union[ - TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] - ] = None, + videos: np.ndarray | list[ImageInput] = None, + text: TextInput + | PreTokenizedInput + | list[TextInput] + | list[PreTokenizedInput] = None, **kwargs: Unpack[Ovis2_5ProcessorKwargs], ) -> BatchFeature: """ @@ -376,12 +376,12 @@ def construct_visual_placeholders(self, grid, is_video: bool = False): def preprocess_multidata( self, - images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None, - video: Optional[Union[list[PIL.Image.Image], np.ndarray]] = None, - convert_to_rgb: Optional[bool] = True, + images: PIL.Image.Image | list[PIL.Image.Image] | None = None, + video: list[PIL.Image.Image] | np.ndarray | None = None, + convert_to_rgb: bool | None = True, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS, - return_tensors: Optional[str] = "pt", + return_tensors: str | None = "pt", ): is_video = False if images is not None: @@ -397,6 +397,8 @@ def preprocess_multidata( images.append(image) elif isinstance(video, list): images = video + else: + raise ValueError("Either images or video should be provided.") min_pixels = min( max_pixels if max_pixels is not None else MAX_PIXELS, min_pixels if min_pixels is not None else MIN_PIXELS, diff --git a/vllm/transformers_utils/runai_utils.py b/vllm/transformers_utils/runai_utils.py index ec60d66e5cff..eac4294bb59c 100644 --- a/vllm/transformers_utils/runai_utils.py +++ b/vllm/transformers_utils/runai_utils.py @@ -5,12 +5,11 @@ import os import shutil import signal -from typing import Optional from vllm import envs from vllm.assets.base import get_cache_dir from vllm.logger import init_logger -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule logger = init_logger(__name__) @@ -88,8 +87,8 @@ def new_handler(signum, frame): def pull_files( self, model_path: str = "", - allow_pattern: Optional[list[str]] = None, - ignore_pattern: Optional[list[str]] = None, + allow_pattern: list[str] | None = None, + ignore_pattern: list[str] | None = None, ) -> None: """ Pull files from object storage into the temporary directory. 
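The typing changes across these transformers_utils modules follow one pattern: `typing.Optional[X]` and `typing.Union[A, B]` annotations are rewritten as PEP 604 unions (`X | None`, `A | B`), which are valid at runtime on Python 3.10+ and so need no `from __future__ import annotations`. A minimal, self-contained sketch of the target style; the helper below is hypothetical and illustrative only, not code taken from this diff:

# Hypothetical example of the annotation style this patch migrates to
# (PEP 604 unions instead of Optional/Union); not part of the patch.
import fnmatch
from pathlib import Path


def list_matching_files(
    root: str | Path,                        # was: Union[str, Path]
    allow_pattern: list[str] | None = None,  # was: Optional[list[str]]
) -> list[str]:
    """Return file names under `root`, optionally filtered by glob patterns."""
    names = [p.name for p in Path(root).iterdir() if p.is_file()]
    if allow_pattern is None:
        return names
    return [n for n in names if any(fnmatch.fnmatch(n, pat) for pat in allow_pattern)]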
diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py index ef30efd80b1f..a5a3af6538b8 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/transformers_utils/s3_utils.py @@ -4,7 +4,7 @@ import fnmatch from typing import TYPE_CHECKING, Optional -from vllm.utils import PlaceholderModule +from vllm.utils.import_utils import PlaceholderModule if TYPE_CHECKING: from botocore.client import BaseClient @@ -34,7 +34,7 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: def glob( s3: Optional["BaseClient"] = None, path: str = "", - allow_pattern: Optional[list[str]] = None, + allow_pattern: list[str] | None = None, ) -> list[str]: """ List full file names from S3 path and filter by allow pattern. @@ -58,8 +58,8 @@ def glob( def list_files( s3: "BaseClient", path: str, - allow_pattern: Optional[list[str]] = None, - ignore_pattern: Optional[list[str]] = None, + allow_pattern: list[str] | None = None, + ignore_pattern: list[str] | None = None, ) -> tuple[str, str, list[str]]: """ List files from S3 path and filter by pattern. diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 9537295c6dcd..a393568909d2 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -7,7 +7,7 @@ import warnings from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, TypeAlias import huggingface_hub from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -21,23 +21,21 @@ if TYPE_CHECKING: from vllm.config import ModelConfig - from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer_base import TokenizerBase else: ModelConfig = Any - LoRARequest = Any TokenizerBase = Any logger = init_logger(__name__) -AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, TokenizerBase] +AnyTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast | TokenizerBase def decode_tokens( tokenizer: AnyTokenizer, token_ids: list[int], *, - skip_special_tokens: Optional[bool] = None, + skip_special_tokens: bool | None = None, ) -> str: """ Backend-agnostic equivalent of HF's @@ -56,9 +54,9 @@ def encode_tokens( tokenizer: AnyTokenizer, text: str, *, - truncation: Optional[bool] = None, - max_length: Optional[int] = None, - add_special_tokens: Optional[bool] = None, + truncation: bool | None = None, + max_length: int | None = None, + add_special_tokens: bool | None = None, ) -> list[int]: """ Backend-agnostic equivalent of HF's @@ -137,12 +135,12 @@ def __reduce__(self): def get_tokenizer( - tokenizer_name: Union[str, Path], + tokenizer_name: str | Path, *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, - revision: Optional[str] = None, - download_dir: Optional[str] = None, + revision: str | None = None, + download_dir: str | None = None, **kwargs, ) -> AnyTokenizer: """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py index 2d64265abbf2..7421eb534808 100644 --- a/vllm/transformers_utils/tokenizer_base.py +++ b/vllm/transformers_utils/tokenizer_base.py @@ -3,7 +3,7 @@ import importlib from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from vllm.entrypoints.chat_utils import 
ChatCompletionMessageParam @@ -71,11 +71,11 @@ def __len__(self) -> int: @abstractmethod def __call__( self, - text: Union[str, list[str], list[int]], - text_pair: Optional[str] = None, + text: str | list[str] | list[int], + text_pair: str | None = None, add_special_tokens: bool = False, truncation: bool = False, - max_length: Optional[int] = None, + max_length: int | None = None, ): raise NotImplementedError() @@ -92,7 +92,7 @@ def encode_one( self, text: str, truncation: bool = False, - max_length: Optional[int] = None, + max_length: int | None = None, ) -> list[int]: raise NotImplementedError() @@ -100,9 +100,9 @@ def encode_one( def encode( self, text: str, - truncation: Optional[bool] = None, - max_length: Optional[int] = None, - add_special_tokens: Optional[bool] = None, + truncation: bool | None = None, + max_length: int | None = None, + add_special_tokens: bool | None = None, ) -> list[int]: raise NotImplementedError() @@ -110,7 +110,7 @@ def encode( def apply_chat_template( self, messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, + tools: list[dict[str, Any]] | None = None, **kwargs, ) -> list[int]: raise NotImplementedError() @@ -120,9 +120,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError() @abstractmethod - def decode( - self, ids: Union[list[int], int], skip_special_tokens: bool = True - ) -> str: + def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str: raise NotImplementedError() @abstractmethod diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 5633a31455e9..6f710bf23360 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -1,34 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os -from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union, cast - -import huggingface_hub -import regex as re -from huggingface_hub import HfApi, hf_hub_download -from transformers.tokenization_utils_base import BatchEncoding +from typing import TYPE_CHECKING, Any, cast from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_base import TokenizerBase -from vllm.utils import is_list_of if TYPE_CHECKING: - # make sure `mistral_common` is lazy imported, - # so that users who only use non-mistral models - # will not be bothered by the dependency. 
- from mistral_common.protocol.instruct.request import ChatCompletionRequest - from mistral_common.tokens.tokenizers.mistral import ( - MistralTokenizer as PublicMistralTokenizer, + from mistral_common.protocol.instruct.request import ( + ChatCompletionRequest as MistralChatCompletionRequest, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as TransformersMistralTokenizer, ) from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + from vllm.entrypoints.openai.protocol import ChatCompletionRequest logger = init_logger(__name__) -def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): +def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"): # SEE: https://github.com/vllm-project/vllm/pull/9951 # Credits go to: @gcalmettes # NOTE: There is currently a bug in pydantic where attributes @@ -65,7 +58,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): request.messages[i]["tool_calls"] = validated_tool_calls -def truncate_tool_call_ids(request: "ChatCompletionRequest"): +def truncate_tool_call_ids(request: "MistralChatCompletionRequest"): """Truncates tool call IDs for Mistral's ID requirements.""" for i, message in enumerate(request.messages): if message.get("role") == "assistant": @@ -95,85 +88,35 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"): request.messages[i]["tool_call_id"] = tool_call_id -def validate_request_params(request: "ChatCompletionRequest"): - if request.skip_special_tokens is not None and not request.skip_special_tokens: +def _prepare_apply_chat_template_tools_and_messages( + messages: list["ChatCompletionMessageParam"], + tools: list[dict[str, Any]] | None = None, + continue_final_message: bool = False, + add_generation_prompt: bool = False, +) -> tuple[list["ChatCompletionMessageParam"], list[dict[str, Any]] | None]: + if add_generation_prompt and continue_final_message: raise ValueError( - "skip_special_tokens=False is not supported for Mistral tokenizers." + "Cannot set both `add_generation_prompt` and " + "`continue_final_message` to True." ) - -def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]: - repo_cache = os.path.join( - huggingface_hub.constants.HF_HUB_CACHE, - huggingface_hub.constants.REPO_ID_SEPARATOR.join( - ["models", *repo_id.split("/")] - ), - ) - - if revision is None: - revision_file = os.path.join(repo_cache, "refs", "main") - if os.path.isfile(revision_file): - with open(revision_file) as file: - revision = file.read() - - if revision: - revision_dir = os.path.join(repo_cache, "snapshots", revision) - if os.path.isdir(revision_dir): - return os.listdir(revision_dir) - - return [] - - -def find_tokenizer_file(files: list[str]): - # Accept both versioned (tokenizer.model.v3) and unversioned - # (tokenizer.model) forms, plus tekken.json and tokenizer.mm.model - # variants. Previous pattern only matched the versioned variants. - file_pattern = re.compile( - r"^tokenizer\.model(\.v.*)?|tekken\.json|tokenizer\.mm\.model(\.v.*)?$" - ) - - matched_files = [file for file in files if file_pattern.match(file)] - if len(matched_files) > 1: - logger.warning( - "Multiple files matched pattern `%s`: %s. 
Using %s.", - file_pattern.pattern, - matched_files, - matched_files[0], + last_message = cast(dict[str, Any], messages[-1]) + # add_generation_prompt is directly handled by the tokenizer but we + # check if the user is trying to use it with a final assistant message + # which is probably not what they want. + # If add_generation_prompt is False, we don't need to check anything. + if add_generation_prompt and last_message["role"] == "assistant": + raise ValueError( + "Cannot set `add_generation_prompt` to True when " + "the last message is from the assistant. Consider " + "using `continue_final_message` instead." ) - elif len(matched_files) == 0: - raise OSError( - f"Found {len(matched_files)} files matching the " - f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " - f"tokenizer is present in {files}." + if continue_final_message and last_message["role"] != "assistant": + raise ValueError( + "Cannot set `continue_final_message` to True when " + "the last message is not from the assistant." ) - return matched_files[0] - - -def _aggregate_content(content: list) -> list[dict[str, Any]]: - aggregated_content: list[dict[str, Any]] = [] - for chunk in content: - if ( - chunk.get("type") == "text" - and aggregated_content - and aggregated_content[-1].get("type") == "text" - ): - aggregated_content[-1]["text"] += "\n\n" + chunk.get("text") - else: - aggregated_content.append(chunk) - if len(aggregated_content) == 1 and aggregated_content[0].get("type") == "text": - content = aggregated_content[0]["text"] - return content - - -def make_mistral_chat_completion_request( - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, -) -> "ChatCompletionRequest": - last_message = cast(dict[str, Any], messages[-1]) - if last_message["role"] == "assistant": - last_message["prefix"] = True - # mistral-common requires AssistantMessage content to be string [1]. # # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80 @@ -181,13 +124,6 @@ def make_mistral_chat_completion_request( # Remove reasoning_content as unsupported by Mistral _ = message.pop("reasoning_content", None) # type: ignore - # Convert list text content to string - if message.get("role") in ("assistant", "tool"): - content: Any = message.get("content") - if isinstance(content, list): - content = _aggregate_content(content) - message["content"] = content - # The Mistral client, in comparison to the OpenAI client, requires the # "parameters" dict and the "description" string to be present # even if they are empty. 
@@ -200,108 +136,113 @@ def make_mistral_chat_completion_request( if function.get("description") is None: function["description"] = "" - from mistral_common.protocol.instruct.request import ChatCompletionRequest + return messages, tools - return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var] +def validate_request_params(request: "ChatCompletionRequest"): + if request.chat_template is not None or request.chat_template_kwargs is not None: + raise ValueError("chat_template is not supported for Mistral tokenizers.") -class MistralTokenizer(TokenizerBase): - def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: - self.mistral = tokenizer - self.instruct = tokenizer.instruct_tokenizer - _mistral_version_str = self.instruct.tokenizer.version.value - self.version: int = int(_mistral_version_str.split("v")[-1]) - tokenizer_ = tokenizer.instruct_tokenizer.tokenizer - from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy - from mistral_common.tokens.tokenizers.tekken import Tekkenizer +def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: + from mistral_common.tokens.tokenizers.tekken import Tekkenizer - self.is_tekken = isinstance(tokenizer_, Tekkenizer) + assert isinstance(tokenizer, Tekkenizer), type(tokenizer) + + t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t + shift = tokenizer.num_special_tokens + try: + return shift + tokenizer._tekken_token2id_nospecial[t_bytes] + except KeyError: + t_str = t_bytes.decode("utf-8") + if t_str in tokenizer._special_tokens_reverse_vocab: + return tokenizer._special_tokens_reverse_vocab[t_str] + logger.warning( + "Failed to convert token %s to id, replacing with <unk>", t_bytes + ) + return tokenizer.unk_id + + +class MistralTokenizer(TokenizerBase): + def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None: from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer - self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - self._special_token_policy = ( - SpecialTokenPolicy.IGNORE if self.is_tekken else None - ) + self.transformers_tokenizer = tokenizer + self.mistral = tokenizer.tokenizer + self.instruct = self.mistral.instruct_tokenizer + self.tokenizer = self.instruct.tokenizer + + _mistral_version_str = str(self.tokenizer.version.value) + self.version: int = int(_mistral_version_str.split("v")[-1]) + + self.is_tekken = isinstance(self.tokenizer, Tekkenizer) + self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer) if not (self.is_tekken or self.is_spm): - raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") - - self._vocab = tokenizer_.vocab() - # Convert to a dict[str, int] to match protocol, but this is a lossy - # conversion. There may be multiple token ids that decode to the same - # string due to partial UTF-8 byte sequences being converted to � - self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} - self.tokenizer = tokenizer_ + raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}") + + # Reverse order to ensure that the lowest token id is kept. + self._vocab_dict = { + self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i + for i in range(self.vocab_size - 1, -1, -1) + } + # Sort the dict for convenience + self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1])) + + # Vocab sorted by token id. 
+ self._vocab = self.tokenizer._vocab self._max_token_id = self.vocab_size - 1 @classmethod def from_pretrained( - cls, path_or_repo_id: str, *, revision: Optional[str] = None + cls, path_or_repo_id: str, *, revision: str | None = None ) -> "MistralTokenizer": - if not Path(path_or_repo_id).exists(): - assert len(path_or_repo_id.split("/")) == 2, ( - "You have either provided a non-existent path: " - "{path_or_repo_id} or an invalid HF Hub repo id." - ) - tokenizer_file = cls._download_mistral_tokenizer_from_hf( - path_or_repo_id, revision - ) - elif Path(path_or_repo_id).is_dir(): - tokenizer_file_name = find_tokenizer_file(os.listdir(path_or_repo_id)) - tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name) - else: - assert Path(path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}" - tokenizer_file = str(Path(path_or_repo_id)) - - from mistral_common.tokens.tokenizers.mistral import ( - MistralTokenizer as PublicMistralTokenizer, + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as TransformersMistralTokenizer, ) - mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file) - return cls(mistral_tokenizer) - - @staticmethod - def _download_mistral_tokenizer_from_hf( - tokenizer_name: str, revision: Optional[str] - ) -> str: - try: - hf_api = HfApi() - files = hf_api.list_repo_files(repo_id=tokenizer_name, revision=revision) - except ConnectionError as exc: - files = list_local_repo_files(repo_id=tokenizer_name, revision=revision) - - if len(files) == 0: - raise exc - - filename = find_tokenizer_file(files) - - tokenizer_file = hf_hub_download( - tokenizer_name, filename=filename, revision=revision + str_revision = "main" if revision is None else revision + return cls( + TransformersMistralTokenizer.from_pretrained( + path_or_repo_id, revision=str_revision + ) ) - return tokenizer_file # the following attributes are set to fit vLLM's design and are used # by the structured output backends. 
@property def all_special_tokens_extended(self) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokens - - # tekken defines its own extended special tokens list - if hasattr(self.tokenizer, "SPECIAL_TOKENS"): - special_tokens = self.tokenizer.SPECIAL_TOKENS - else: - special_tokens = list(SpecialTokens) - return [s.value if isinstance(s, SpecialTokens) else s for s in special_tokens] + return self.all_special_tokens @property def all_special_tokens(self) -> list[str]: - return self.all_special_tokens_extended + from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + + return [ + self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP) + for i in self.all_special_ids + ] @property def all_special_ids(self) -> list[int]: - return [self.all_special_tokens.index(t) for t in self.all_special_tokens] + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + if self.is_tekken: + assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) + special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens} + elif self.is_spm: + assert isinstance(self.tokenizer, SentencePieceTokenizer), type( + self.tokenizer + ) + special_ids = self.tokenizer._control_tokens + else: + raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") + return sorted(special_ids) @property def bos_token_id(self) -> int: @@ -317,7 +258,7 @@ def sep_token(self) -> str: @property def pad_token(self) -> str: - raise NotImplementedError() + return self.transformers_tokenizer.pad_token @property def is_fast(self) -> bool: @@ -325,7 +266,7 @@ def is_fast(self) -> bool: @property def vocab_size(self) -> int: - return len(self._vocab) + return self.transformers_tokenizer.vocab_size @property def max_token_id(self) -> int: @@ -335,36 +276,47 @@ def max_token_id(self) -> int: def truncation_side(self) -> str: raise NotImplementedError() + def _is_special_token_id(self, token_id: int) -> bool: + from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + + if self.is_spm: + assert isinstance(self.tokenizer, SentencePieceTokenizer), type( + self.tokenizer + ) + return token_id in self.tokenizer._control_tokens + if self.is_tekken: + assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) + return token_id < self.tokenizer.num_special_tokens + else: + raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") + def __len__(self) -> int: return self.vocab_size def __call__( self, - text: Union[str, list[str], list[int]], - text_pair: Optional[str] = None, + text: str | list[str] | list[int], + text_pair: str | None = None, add_special_tokens: bool = False, truncation: bool = False, - max_length: Optional[int] = None, + max_length: int | None = None, ): - input_ids: Union[list[int], list[list[int]]] - # For list[str], original prompt text - if is_list_of(text, str): - input_ids_: list[list[int]] = [] - for p in text: - each_input_ids = self.encode_one(p, truncation, max_length) - input_ids_.append(each_input_ids) - input_ids = input_ids_ - # For list[int], apply chat template output, already tokens. 
- elif is_list_of(text, int): - input_ids = text - # For str, single prompt text - else: - input_ids = self.encode_one(text, truncation, max_length) - return BatchEncoding({"input_ids": input_ids}) + return self.transformers_tokenizer( + text=text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + truncation=truncation, + max_length=max_length, + ) + + @property + def vocab(self) -> list[str]: + return self._vocab def get_vocab(self) -> dict[str, int]: - # NB: the dictionary form of the vocabulary collapses token ids that map - # to the same string but have different bytes return self._vocab_dict def get_added_vocab(self) -> dict[str, int]: @@ -375,91 +327,112 @@ def encode_one( self, text: str, truncation: bool = False, - max_length: Optional[int] = None, + max_length: int | None = None, ) -> list[int]: # Mistral Tokenizers should not add special tokens - input_ids = self.encode(text) - - if truncation: - input_ids = input_ids[:max_length] - return input_ids + return self.transformers_tokenizer.encode( + text, add_special_tokens=False, truncation=truncation, max_length=max_length + ) def encode( self, text: str, - truncation: Optional[bool] = None, - max_length: Optional[int] = None, - add_special_tokens: Optional[bool] = None, + truncation: bool | None = None, + max_length: int | None = None, + add_special_tokens: bool | None = None, ) -> list[int]: - # `encode` should only be used for prompt completion - # it should never be used for chat_completion. - # For chat completion use `apply_chat_template` if add_special_tokens is not None: - return self.tokenizer.encode( - text, bos=add_special_tokens, eos=add_special_tokens + return self.transformers_tokenizer.encode( + text, + truncation=truncation, + max_length=max_length, + add_special_tokens=add_special_tokens, ) else: - return self.tokenizer.encode(text, bos=True, eos=False) + encoded = self.tokenizer.encode(text, bos=True, eos=False) + + if truncation is not False and max_length is not None: + return encoded[:max_length] + else: + return encoded def apply_chat_template( self, messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, + tools: list[dict[str, Any]] | None = None, **kwargs, ) -> list[int]: - request = make_mistral_chat_completion_request(messages, tools) - encoded = self.mistral.encode_chat_completion(request) + add_generation_prompt = kwargs.pop("add_generation_prompt", False) + continue_final_message = kwargs.get("continue_final_message", False) + padding = kwargs.get("padding", False) + truncation = kwargs.get("truncation", False) + max_length = kwargs.get("max_length") + + messages, tools = _prepare_apply_chat_template_tools_and_messages( + messages, tools, continue_final_message, add_generation_prompt + ) - # encode-decode to get clean prompt - return encoded.tokens + return self.transformers_tokenizer.apply_chat_template( + conversation=messages, + tools=tools, + continue_final_message=continue_final_message, + tokenize=True, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=None, + return_dict=False, + ) + + def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str: + return self.transformers_tokenizer.decode( + ids, skip_special_tokens=skip_special_tokens + ) def convert_tokens_to_string(self, tokens: list[str]) -> str: - from mistral_common.tokens.tokenizers.base import SpecialTokens + from mistral_common.tokens.tokenizers.base import ( + SpecialTokenPolicy, + SpecialTokens, + ) + from 
mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer, + ) + from mistral_common.tokens.tokenizers.tekken import Tekkenizer + to_decode_special_tokens = {SpecialTokens.tool_calls} if self.is_tekken: + assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) tokens = [ t for t in tokens - if ( - t is SpecialTokens.tool_calls - or t not in self.tokenizer._all_special_tokens - ) + if (t in to_decode_special_tokens or t not in self.all_special_tokens) ] if any(isinstance(t, bytes) for t in tokens): # we need to encode and decode all tokens again - shift = self.tokenizer.num_special_tokens - - def _token_to_id(t: str): - t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t - try: - return ( - shift + self.tokenizer._tekken_token2id_nospecial[t_bytes] - ) - except KeyError: - logger.warning( - "Failed to convert token %s to id, replacing with <unk>", - t_bytes, - ) - return self.tokenizer.unk_id - - ids = [_token_to_id(t) for t in tokens] - decoded = self.tokenizer.decode(ids, self._special_token_policy) + ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens] + # We filtered unwanted special tokens before + # so we can decode the rest. + decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP) else: decoded = "".join(tokens) else: # make sure certain special tokens like Tool calls are # not decoded - special_tokens = {SpecialTokens.tool_calls} + assert isinstance(self.tokenizer, SentencePieceTokenizer), type( + self.tokenizer + ) + regular_tokens: list[str] = [] - decoded_list = [] + decoded_list: list[str] = [] + decoded = "" for token in tokens: - if token in special_tokens: + if token in to_decode_special_tokens: if regular_tokens: decoded_list.append( self.tokenizer.decode( - regular_tokens, self._special_token_policy + regular_tokens, SpecialTokenPolicy.IGNORE ) ) regular_tokens = [] @@ -469,66 +442,56 @@ def _token_to_id(t: str): if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens, self._special_token_policy) + self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE) ) - decoded = "".join(decoded_list) return decoded - def decode( - self, ids: Union[list[int], int], skip_special_tokens: bool = True - ) -> str: - assert skip_special_tokens, ( - "skip_special_tokens=False is not supported for Mistral tokenizers." - ) - - if isinstance(ids, int): - ids = [ids] - return self.tokenizer.decode(ids, self._special_token_policy) - def convert_ids_to_tokens( self, ids: list[int], skip_special_tokens: bool = True, ) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokens + from mistral_common.tokens.tokenizers.base import ( + SpecialTokenPolicy, + SpecialTokens, + ) from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13 - # TODO(Patrick) - potentially allow special tokens to not be skipped - assert skip_special_tokens, ( - "skip_special_tokens=False is not supported for Mistral tokenizers." 
- ) + if not skip_special_tokens: + return [self.tokenizer.id_to_piece(token_id) for token_id in ids] - assert self.is_tekken or self.is_spm, type(self.tokenizer) + non_skip_special_tokens_ids = { + self.tokenizer.get_control_token(SpecialTokens.tool_calls), + } + if isinstance(self.instruct, InstructTokenizerV13): + if self.instruct.BEGIN_THINK: + non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK) + if self.instruct.END_THINK: + non_skip_special_tokens_ids.add(self.instruct.END_THINK) - if self.is_tekken: - # skip special tokens except tool call and think tokens - non_skip_special_tokens = { - self.tokenizer.get_control_token(SpecialTokens.tool_calls) - } - if isinstance(self.instruct, InstructTokenizerV13): - if self.instruct.BEGIN_THINK: - non_skip_special_tokens.add(self.instruct.BEGIN_THINK) - if self.instruct.END_THINK: - non_skip_special_tokens.add(self.instruct.END_THINK) - ids = [ - i - for i in ids - if i > self.tokenizer.num_special_tokens or i in non_skip_special_tokens - ] + ids_kept = [ + i + for i in ids + if i in non_skip_special_tokens_ids or not self._is_special_token_id(i) + ] - tokens = [self.tokenizer.id_to_piece(id) for id in ids] + # We filtered unwanted special tokens so we can decode the rest. + tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept] if any("�" in t for t in tokens) and self.is_tekken: # if a decoded token contains the replacement character, then the # token has an incomplete UTF-8 character so we must use bytes # See: https://github.com/vllm-project/vllm/pull/8640 # https://github.com/vllm-project/vllm/pull/9625 - # if underlying tokenizeir is sentencepiece, we just add "�" + # if underlying tokenizer is sentencepiece, we just add "�". + # We filtered unwanted special tokens so we can decode the rest. 
tokens = [ - self.tokenizer.id_to_byte_piece(id, self._special_token_policy) - for id in ids + self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP) + if token_id not in self.all_special_ids + else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP) + for token_id in ids_kept ] return tokens diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 8952a0b197d6..58c754dbd397 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -6,9 +6,9 @@ from functools import cache from os import PathLike from pathlib import Path -from typing import Any, Optional, Union +from typing import Any -from vllm.envs import VLLM_MODEL_REDIRECT_PATH +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -18,7 +18,7 @@ def is_s3(model_or_path: str) -> bool: return model_or_path.lower().startswith("s3://") -def check_gguf_file(model: Union[str, PathLike]) -> bool: +def check_gguf_file(model: str | PathLike) -> bool: """Check if the file is a GGUF model.""" model = Path(model) if not model.is_file(): @@ -38,8 +38,8 @@ def check_gguf_file(model: Union[str, PathLike]) -> bool: def modelscope_list_repo_files( repo_id: str, - revision: Optional[str] = None, - token: Union[str, bool, None] = None, + revision: str | None = None, + token: str | bool | None = None, ) -> list[str]: """List files in a modelscope repo.""" from modelscope.hub.api import HubApi @@ -57,7 +57,7 @@ def modelscope_list_repo_files( return files -def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]: +def _maybe_json_dict(path: str | PathLike) -> dict[str, str]: with open(path) as f: try: return json.loads(f.read()) @@ -65,7 +65,7 @@ def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]: return dict[str, str]() -def _maybe_space_split_dict(path: Union[str, PathLike]) -> dict[str, str]: +def _maybe_space_split_dict(path: str | PathLike) -> dict[str, str]: parsed_dict = dict[str, str]() with open(path) as f: for line in f.readlines(): @@ -86,7 +86,7 @@ def maybe_model_redirect(model: str) -> str: :return: maybe redirect to a local folder """ - model_redirect_path = VLLM_MODEL_REDIRECT_PATH + model_redirect_path = envs.VLLM_MODEL_REDIRECT_PATH if not model_redirect_path: return model @@ -104,7 +104,7 @@ def maybe_model_redirect(model: str) -> str: return model -def parse_safetensors_file_metadata(path: Union[str, PathLike]) -> dict[str, Any]: +def parse_safetensors_file_metadata(path: str | PathLike) -> dict[str, Any]: with open(path, "rb") as f: length_of_metadata = struct.unpack("<Q", f.read(8))[0] metadata = json.loads(f.read(length_of_metadata).decode("utf-8")) diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index e1a509a303c5..f05bc555bfdc 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -98,3 +98,6 @@ def __init__(self): self.int64 = None self.int32 = None self.tensor = None + self.exp = None + self.log = None + self.log2 = None diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index ed470ebe8892..c8bff8b7c80b 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -10,7 +10,7 @@ from enum import Enum from pathlib import Path from threading import Thread -from typing import Any, Optional, Union +from typing import Any from uuid import uuid4 import cpuinfo @@ -21,7 +21,8 @@ import vllm.envs as envs from vllm.connections import global_http_connection from vllm.logger import init_logger -from vllm.utils import 
cuda_device_count_stateless, cuda_get_device_properties +from vllm.utils.platform_utils import cuda_get_device_properties +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -32,7 +33,7 @@ _USAGE_STATS_ENABLED = None _USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER -_GLOBAL_RUNTIME_DATA = dict[str, Union[str, int, bool]]() +_GLOBAL_RUNTIME_DATA = dict[str, str | int | bool]() _USAGE_ENV_VARS_TO_COLLECT = [ "VLLM_USE_MODELSCOPE", @@ -46,7 +47,7 @@ ] -def set_runtime_usage_data(key: str, value: Union[str, int, bool]) -> None: +def set_runtime_usage_data(key: str, value: str | int | bool) -> None: """Set global usage data that will be sent with every usage heartbeat.""" _GLOBAL_RUNTIME_DATA[key] = value @@ -131,33 +132,33 @@ def __init__(self) -> None: self.uuid = str(uuid4()) # Environment Information - self.provider: Optional[str] = None - self.num_cpu: Optional[int] = None - self.cpu_type: Optional[str] = None - self.cpu_family_model_stepping: Optional[str] = None - self.total_memory: Optional[int] = None - self.architecture: Optional[str] = None - self.platform: Optional[str] = None - self.cuda_runtime: Optional[str] = None - self.gpu_count: Optional[int] = None - self.gpu_type: Optional[str] = None - self.gpu_memory_per_device: Optional[int] = None - self.env_var_json: Optional[str] = None + self.provider: str | None = None + self.num_cpu: int | None = None + self.cpu_type: str | None = None + self.cpu_family_model_stepping: str | None = None + self.total_memory: int | None = None + self.architecture: str | None = None + self.platform: str | None = None + self.cuda_runtime: str | None = None + self.gpu_count: int | None = None + self.gpu_type: str | None = None + self.gpu_memory_per_device: int | None = None + self.env_var_json: str | None = None # vLLM Information - self.model_architecture: Optional[str] = None - self.vllm_version: Optional[str] = None - self.context: Optional[str] = None + self.model_architecture: str | None = None + self.vllm_version: str | None = None + self.context: str | None = None # Metadata - self.log_time: Optional[int] = None - self.source: Optional[str] = None + self.log_time: int | None = None + self.source: str | None = None def report_usage( self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Optional[dict[str, Any]] = None, + extra_kvs: dict[str, Any] | None = None, ) -> None: t = Thread( target=self._report_usage_worker, @@ -175,6 +176,32 @@ def _report_usage_worker( self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_continuous_usage() + def _report_tpu_inference_usage(self) -> bool: + try: + from tpu_inference import tpu_info, utils + + self.gpu_count = tpu_info.get_num_chips() + self.gpu_type = tpu_info.get_tpu_type() + self.gpu_memory_per_device = utils.get_device_hbm_limit() + self.cuda_runtime = "tpu_inference" + return True + except Exception: + return False + + def _report_torch_xla_usage(self) -> bool: + try: + import torch_xla + + self.gpu_count = torch_xla.runtime.world_size() + self.gpu_type = torch_xla.tpu.get_tpu_type() + self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[ + "bytes_limit" + ] + self.cuda_runtime = "torch_xla" + return True + except Exception: + return False + def _report_usage_once( self, model_architecture: str, @@ -191,16 +218,10 @@ def _report_usage_once( ) if current_platform.is_cuda(): self.cuda_runtime = torch.version.cuda - if 
current_platform.is_tpu(): - try: - import torch_xla - - self.gpu_count = torch_xla.runtime.world_size() - self.gpu_type = torch_xla.tpu.get_tpu_type() - self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[ - "bytes_limit" - ] - except Exception: + if current_platform.is_tpu(): # noqa: SIM102 + if (not self._report_tpu_inference_usage()) and ( + not self._report_torch_xla_usage() + ): logger.exception("Failed to collect TPU information") self.provider = _detect_cloud_provider() self.architecture = platform.machine() diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 4a6a79ad067b..38da04102c44 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1,35 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - -import asyncio -import concurrent import contextlib import datetime import enum -import gc import getpass -import hashlib -import importlib -import importlib.metadata -import importlib.util import inspect -import ipaddress import json import multiprocessing import os -import pickle import signal -import socket -import subprocess import sys import tempfile import textwrap import threading -import time import traceback -import types import uuid import warnings import weakref @@ -41,64 +26,58 @@ RawDescriptionHelpFormatter, _ArgumentGroup, ) -from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task -from collections import UserDict, defaultdict -from collections.abc import ( - AsyncGenerator, - Awaitable, - Collection, - Generator, - Hashable, - Iterable, - Iterator, - Mapping, - Sequence, -) -from concurrent.futures import ThreadPoolExecutor -from concurrent.futures.process import ProcessPoolExecutor -from dataclasses import dataclass, field -from functools import cache, lru_cache, partial, wraps -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Literal, - TextIO, - TypeVar, - Union, -) -from urllib.parse import urlparse -from uuid import uuid4 +from collections import defaultdict +from collections.abc import Callable +from functools import partial, wraps +from typing import TYPE_CHECKING, Any, TypeVar -import cbor2 import cloudpickle -import numpy as np -import numpy.typing as npt import psutil import regex as re -import setproctitle import torch -import torch.types import yaml -import zmq -import zmq.asyncio -from packaging import version -from packaging.version import Version -from torch.library import Library -from transformers.tokenization_utils_base import BatchEncoding -from typing_extensions import Never, ParamSpec, TypeIs, assert_never import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger from vllm.ray.lazy_utils import is_in_ray_actor +from vllm.utils.platform_utils import cuda_is_initialized, xpu_is_initialized + +_DEPRECATED_MAPPINGS = { + "cprofile": "profiling", + "cprofile_context": "profiling", + "get_open_port": "network_utils", +} + + +def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring + """Module-level getattr to handle deprecated utilities.""" + if name in _DEPRECATED_MAPPINGS: + submodule_name = _DEPRECATED_MAPPINGS[name] + warnings.warn( + f"vllm.utils.{name} is deprecated and will be removed in a future version. 
" + f"Use vllm.utils.{submodule_name}.{name} instead.", + DeprecationWarning, + stacklevel=2, + ) + module = __import__(f"vllm.utils.{submodule_name}", fromlist=[submodule_name]) + return getattr(module, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + # expose deprecated names in dir() for better UX/tab-completion + return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys())) + if TYPE_CHECKING: from argparse import Namespace from vllm.config import ModelConfig, VllmConfig - from vllm.sequence import IntermediateTensors +else: + Namespace = object + + ModelConfig = object + VllmConfig = object logger = init_logger(__name__) @@ -123,61 +102,15 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" -MB_bytes = 1_000_000 -"""The number of bytes in one megabyte (MB).""" - -MiB_bytes = 1 << 20 -"""The number of bytes in one mebibyte (MiB).""" - -GB_bytes = 1_000_000_000 -"""The number of bytes in one gigabyte (GB).""" - -GiB_bytes = 1 << 30 -"""The number of bytes in one gibibyte (GiB).""" # ANSI color codes CYAN = "\033[1;36m" RESET = "\033[0;0m" -STR_DTYPE_TO_TORCH_DTYPE = { - "float32": torch.float32, - "half": torch.half, - "bfloat16": torch.bfloat16, - "float": torch.float, - "fp8": torch.uint8, - "fp8_e4m3": torch.uint8, - "fp8_e5m2": torch.uint8, - "int8": torch.int8, - "fp8_inc": torch.float8_e4m3fn, - "fp8_ds_mla": torch.uint8, -} - -TORCH_DTYPE_TO_NUMPY_DTYPE = { - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.uint8: np.uint8, - torch.int32: np.int32, - torch.int64: np.int64, -} - - -@contextlib.contextmanager -def set_default_torch_num_threads(num_threads: int): - """Sets the default number of threads for PyTorch to the given value.""" - old_num_threads = torch.get_num_threads() - torch.set_num_threads(num_threads) - yield - torch.set_num_threads(old_num_threads) - -P = ParamSpec("P") T = TypeVar("T") U = TypeVar("U") -_K = TypeVar("_K", bound=Hashable) -_V = TypeVar("_V") - class Device(enum.Enum): GPU = enum.auto() @@ -202,527 +135,10 @@ def reset(self) -> None: self.counter = 0 -@cache -def get_max_shared_memory_bytes(gpu: int = 0) -> int: - """Returns the maximum shared memory per thread block in bytes.""" - from vllm import _custom_ops as ops - - max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu) - # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py - # will fail - assert max_shared_mem > 0, "max_shared_mem can not be zero" - return int(max_shared_mem) - - -def get_cpu_memory() -> int: - """Returns the total CPU memory of the node in bytes.""" - return psutil.virtual_memory().total - - def random_uuid() -> str: return str(uuid.uuid4().hex) -class AsyncMicrobatchTokenizer: - """Asynchronous tokenizer with micro-batching. - - Pulls pending encode/decode requests from a queue and batches them - up to reduce overhead. A single-thread ThreadPoolExecutor is used - so the event loop stays responsive. 
- """ - - def __init__( - self, - tokenizer, - max_batch_size: int = 32, - batch_wait_timeout_s: float = 0.002, - ) -> None: - self.tokenizer = tokenizer - self.max_batch_size = max_batch_size - self.batch_wait_timeout_s = batch_wait_timeout_s - - self._loop = asyncio.get_running_loop() - self._queues: dict[ - tuple, - asyncio.Queue[ - Union[ - tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future] - ] - ], - ] = {} - self._batcher_tasks: list[asyncio.Task] = [] - - # Single-thread executor for blocking tokenizer calls. - self._executor = ThreadPoolExecutor(max_workers=1) - - # === Public async API === - async def __call__(self, prompt, **kwargs): - result_future: asyncio.Future = self._loop.create_future() - key = self._queue_key("encode", kwargs) - queue = self._get_queue(self._loop, key) - await queue.put((prompt, kwargs, result_future)) - return await result_future - - async def decode(self, token_ids, **kwargs): - result_future: asyncio.Future = self._loop.create_future() - key = self._queue_key("decode", kwargs) - queue = self._get_queue(self._loop, key) - await queue.put((token_ids, result_future)) - return await result_future - - # === Internal helpers === - def _get_queue( - self, loop: asyncio.AbstractEventLoop, key: tuple - ) -> asyncio.Queue[ - Union[tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future]] - ]: - """Get the request queue for the given operation key, creating a new - queue and batcher task if needed.""" - queue = self._queues.get(key) - if queue is None: - self._queues[key] = queue = asyncio.Queue() - if key[0] == "encode": - can_batch = key[1] != "other" - coro = self._batch_encode_loop(queue, can_batch) - else: - assert key[0] == "decode", f"Unknown operation type: {key[0]}." - coro = self._batch_decode_loop(queue) - self._batcher_tasks.append(loop.create_task(coro)) - return queue - - async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): - """Batch incoming encode requests for efficiency.""" - while True: - prompt, kwargs, result_future = await queue.get() - prompts = [prompt] - kwargs_list = [kwargs] - result_futures = [result_future] - deadline = self._loop.time() + self.batch_wait_timeout_s - - while len(prompts) < self.max_batch_size: - timeout = deadline - self._loop.time() - if timeout <= 0: - break - try: - prompt, kwargs, result_future = await asyncio.wait_for( - queue.get(), timeout - ) - prompts.append(prompt) - result_futures.append(result_future) - if not can_batch: - kwargs_list.append(kwargs) - except asyncio.TimeoutError: - break - - try: - # If every request uses identical kwargs we can run a single - # batched tokenizer call for a big speed-up. 
- if can_batch and len(prompts) > 1: - batch_encode_fn = partial(self.tokenizer, prompts, **kwargs) - results = await self._loop.run_in_executor( - self._executor, batch_encode_fn - ) - - for i, fut in enumerate(result_futures): - if not fut.done(): - data = {k: v[i] for k, v in results.items()} - fut.set_result(BatchEncoding(data)) - else: - encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ - self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs) - ] - results = await self._loop.run_in_executor( - self._executor, encode_fn - ) - - for fut, res in zip(result_futures, results): - if not fut.done(): - fut.set_result(res) - except Exception as e: - for fut in result_futures: - if not fut.done(): - fut.set_exception(e) - - async def _batch_decode_loop(self, queue: asyncio.Queue): - """Batch incoming decode requests for efficiency.""" - while True: - token_ids, result_future = await queue.get() - token_ids_list = [token_ids] - result_futures = [result_future] - deadline = self._loop.time() + self.batch_wait_timeout_s - - while len(token_ids_list) < self.max_batch_size: - timeout = deadline - self._loop.time() - if timeout <= 0: - break - try: - token_ids, result_future = await asyncio.wait_for( - queue.get(), timeout - ) - token_ids_list.append(token_ids) - result_futures.append(result_future) - except asyncio.TimeoutError: - break - - try: - # Perform a single batched decode call for all requests - results = await self._loop.run_in_executor( - self._executor, self.tokenizer.batch_decode, token_ids_list - ) - for fut, res in zip(result_futures, results): - if not fut.done(): - fut.set_result(res) - except Exception as e: - for fut in result_futures: - if not fut.done(): - fut.set_exception(e) - - def _queue_key(self, op: str, kwargs: dict) -> tuple: - """ - Return a normalized key describing operation + kwargs. - - - `add_special_tokens`: {True/False} - - `truncation`: {True/False} - - If `truncation` is False (`max_length` is None), - returns a key for a can_batch queue. - - If `truncation` is True and `max_length` is None or equals - `tokenizer.model_max_length`, returns a key for a can_batch queue. - - Otherwise, returns a key for a cannot_batch queue. 
- - Examples: - - Decode: ("decode",) - - Encode typical: - ("encode", add_special_tokens, bool_truncation, max_length_label) - - Fallback: ("encode", "other") - """ - - if op == "decode": - return ("decode",) - - add_special_tokens = kwargs.get("add_special_tokens", True) - truncation = kwargs.get("truncation", False) - max_length = kwargs.get("max_length") - - if not truncation: - return "encode", add_special_tokens, False, None - - model_max = getattr(self.tokenizer, "model_max_length", None) - if max_length is None or (model_max is not None and max_length == model_max): - return "encode", add_special_tokens, True, "model_max" - - return "encode", "other" - - def __del__(self): - if ( - (tasks := getattr(self, "_batcher_tasks", None)) - and (loop := getattr(self, "_loop", None)) - and not loop.is_closed() - ): - - def cancel_tasks(): - for task in tasks: - task.cancel() - - loop.call_soon_threadsafe(cancel_tasks) - - -def cancel_task_threadsafe(task: Task): - if task and not task.done(): - run_in_loop(task.get_loop(), task.cancel) - - -def close_sockets(sockets: Sequence[Union[zmq.Socket, zmq.asyncio.Socket]]): - for sock in sockets: - if sock is not None: - sock.close(linger=0) - - -def run_in_loop(loop: AbstractEventLoop, function: Callable, *args): - if in_loop(loop): - function(*args) - elif not loop.is_closed(): - loop.call_soon_threadsafe(function, *args) - - -def in_loop(event_loop: AbstractEventLoop) -> bool: - try: - return asyncio.get_running_loop() == event_loop - except RuntimeError: - return False - - -def make_async( - func: Callable[P, T], executor: concurrent.futures.Executor | None = None -) -> Callable[P, Awaitable[T]]: - """Take a blocking function, and run it on in an executor thread. - - This function prevents the blocking function from blocking the - asyncio event loop. - The code in this function needs to be thread safe. - """ - - def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: - loop = asyncio.get_event_loop() - p_func = partial(func, *args, **kwargs) - return loop.run_in_executor(executor=executor, func=p_func) - - return _async_wrapper - - -def _next_task(iterator: AsyncGenerator[T, None], loop: AbstractEventLoop) -> Task: - # Can use anext() in python >= 3.10 - return loop.create_task(iterator.__anext__()) # type: ignore[arg-type] - - -async def merge_async_iterators( - *iterators: AsyncGenerator[T, None], -) -> AsyncGenerator[tuple[int, T], None]: - """Merge multiple asynchronous iterators into a single iterator. - - This method handle the case where some iterators finish before others. - When it yields, it yields a tuple (i, item) where i is the index of the - iterator that yields the item. - """ - if len(iterators) == 1: - # Fast-path single iterator case. 
- async for item in iterators[0]: - yield 0, item - return - - loop = asyncio.get_running_loop() - - awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)} - try: - while awaits: - done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) - for d in done: - pair = awaits.pop(d) - try: - item = await d - i, it = pair - awaits[_next_task(it, loop)] = pair - yield i, item - except StopAsyncIteration: - pass - finally: - # Cancel any remaining iterators - for f, (_, it) in awaits.items(): - with contextlib.suppress(BaseException): - f.cancel() - await it.aclose() - - -async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]: - """Collect all items from an async generator into a list.""" - items = [] - async for item in iterator: - items.append(item) - return items - - -def get_ip() -> str: - host_ip = envs.VLLM_HOST_IP - if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ: - logger.warning( - "The environment variable HOST_IP is deprecated and ignored, as" - " it is often used by Docker and other software to" - " interact with the container's network stack. Please " - "use VLLM_HOST_IP instead to set the IP address for vLLM processes" - " to communicate with each other." - ) - if host_ip: - return host_ip - - # IP is not set, try to get it from the network interface - - # try ipv4 - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - # try ipv6 - try: - s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - # Google's public DNS server, see - # https://developers.google.com/speed/public-dns/docs/using#addresses - s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - warnings.warn( - "Failed to get the IP address, using 0.0.0.0 by default." - "The value can be set by the environment variable" - " VLLM_HOST_IP or HOST_IP.", - stacklevel=2, - ) - return "0.0.0.0" - - -def test_loopback_bind(address, family): - try: - s = socket.socket(family, socket.SOCK_DGRAM) - s.bind((address, 0)) # Port 0 = auto assign - s.close() - return True - except OSError: - return False - - -def get_loopback_ip() -> str: - loopback_ip = envs.VLLM_LOOPBACK_IP - if loopback_ip: - return loopback_ip - - # VLLM_LOOPBACK_IP is not set, try to get it based on network interface - - if test_loopback_bind("127.0.0.1", socket.AF_INET): - return "127.0.0.1" - elif test_loopback_bind("::1", socket.AF_INET6): - return "::1" - else: - raise RuntimeError( - "Neither 127.0.0.1 nor ::1 are bound to a local interface. " - "Set the VLLM_LOOPBACK_IP environment variable explicitly." 
- ) - - -def is_valid_ipv6_address(address: str) -> bool: - try: - ipaddress.IPv6Address(address) - return True - except ValueError: - return False - - -def split_host_port(host_port: str) -> tuple[str, int]: - # ipv6 - if host_port.startswith("["): - host, port = host_port.rsplit("]", 1) - host = host[1:] - port = port.split(":")[1] - return host, int(port) - else: - host, port = host_port.split(":") - return host, int(port) - - -def join_host_port(host: str, port: int) -> str: - if is_valid_ipv6_address(host): - return f"[{host}]:{port}" - else: - return f"{host}:{port}" - - -def get_distributed_init_method(ip: str, port: int) -> str: - return get_tcp_uri(ip, port) - - -def get_tcp_uri(ip: str, port: int) -> str: - if is_valid_ipv6_address(ip): - return f"tcp://[{ip}]:{port}" - else: - return f"tcp://{ip}:{port}" - - -def get_open_zmq_ipc_path() -> str: - base_rpc_path = envs.VLLM_RPC_BASE_PATH - return f"ipc://{base_rpc_path}/{uuid4()}" - - -def get_open_zmq_inproc_path() -> str: - return f"inproc://{uuid4()}" - - -def get_open_port() -> int: - """ - Get an open port for the vLLM process to listen on. - An edge case to handle, is when we run data parallel, - we need to avoid ports that are potentially used by - the data parallel master process. - Right now we reserve 10 ports for the data parallel master - process. Currently it uses 2 ports. - """ - if "VLLM_DP_MASTER_PORT" in os.environ: - dp_master_port = envs.VLLM_DP_MASTER_PORT - reserved_port_range = range(dp_master_port, dp_master_port + 10) - while True: - candidate_port = _get_open_port() - if candidate_port not in reserved_port_range: - return candidate_port - return _get_open_port() - - -def get_open_ports_list(count: int = 5) -> list[int]: - """Get a list of open ports.""" - ports = set[int]() - while len(ports) < count: - ports.add(get_open_port()) - return list(ports) - - -def _get_open_port() -> int: - port = envs.VLLM_PORT - if port is not None: - while True: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", port)) - return port - except OSError: - port += 1 # Increment port number if already in use - logger.info("Port %d is already in use, trying port %d", port - 1, port) - # try ipv4 - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - except OSError: - # try ipv6 - with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -def find_process_using_port(port: int) -> psutil.Process | None: - # TODO: We can not check for running processes with network - # port on macOS. Therefore, we can not have a full graceful shutdown - # of vLLM. For now, let's not look for processes in this case. 
- # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/ - if sys.platform.startswith("darwin"): - return None - - our_pid = os.getpid() - for conn in psutil.net_connections(): - if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid): - try: - return psutil.Process(conn.pid) - except psutil.NoSuchProcess: - return None - return None - - -def update_environment_variables(envs: dict[str, str]): - for k, v in envs.items(): - if k in os.environ and os.environ[k] != v: - logger.warning( - "Overwriting environment variable %s from '%s' to '%s'", - k, - os.environ[k], - v, - ) - os.environ[k] = v - - -def chunk_list(lst: list[T], chunk_size: int): - """Yield successive chunk_size chunks from lst.""" - for i in range(0, len(lst), chunk_size): - yield lst[i : i + chunk_size] - - def cdiv(a: int, b: int) -> int: """Ceiling division.""" return -(a // -b) @@ -750,344 +166,6 @@ def round_down(x: int, y: int) -> int: return (x // y) * y -def _generate_random_fp8( - tensor: torch.Tensor, - low: float, - high: float, -) -> None: - # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, - # it may occur Inf or NaN if we directly use torch.randint - # to generate random data for fp8 data. - # For example, s.11111.00 in fp8e5m2 format represents Inf. - # | E4M3 | E5M2 - # -----|-------------|------------------- - # Inf | N/A | s.11111.00 - # NaN | s.1111.111 | s.11111.{01,10,11} - from vllm import _custom_ops as ops - - tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) - tensor_tmp.uniform_(low, high) - ops.convert_fp8(tensor, tensor_tmp) - del tensor_tmp - - -def get_kv_cache_torch_dtype( - cache_dtype: Union[str, torch.dtype] | None, - model_dtype: Union[str, torch.dtype] | None = None, -) -> torch.dtype: - if isinstance(cache_dtype, str): - if cache_dtype == "auto": - if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] - elif isinstance(model_dtype, torch.dtype): - torch_dtype = model_dtype - else: - raise ValueError(f"Invalid model dtype: {model_dtype}") - elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] - else: - raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") - elif isinstance(cache_dtype, torch.dtype): - torch_dtype = cache_dtype - else: - raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") - return torch_dtype - - -def create_kv_caches_with_random_flash( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: Union[str, torch.dtype] | None, - model_dtype: Union[str, torch.dtype] | None = None, - seed: int | None = None, - device: str | None = "cuda", - cache_layout: str | None = "NHD", -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - from vllm.platforms import current_platform - - current_platform.seed_everything(seed) - - torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) - assert cache_layout in ("NHD", "HND") - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) - - kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) - scale = head_size**-0.5 - - key_caches: list[torch.Tensor] = [] - value_caches: list[torch.Tensor] = [] - - for _ in range(num_layers): - key_value_cache = torch.empty( - size=kv_cache_allocation_shape, dtype=torch_dtype, device=device - ).permute(*stride_order) - if cache_dtype in 
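# Illustrative sketch: why cdiv's -(a // -b) is ceiling division. Python's //
# floors toward negative infinity, so negating the dividend (and the result)
# turns the floor into a ceiling without floats or math.ceil.
def sketch_cdiv(a: int, b: int) -> int:
    return -(a // -b)


assert sketch_cdiv(9, 4) == 3  # ceil(2.25)
assert sketch_cdiv(8, 4) == 2  # exact division is unchanged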
["auto", "half", "bfloat16", "float"]: - key_value_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(key_value_cache, -scale, scale) - else: - raise ValueError(f"Does not support key cache of type {cache_dtype}") - key_caches.append(key_value_cache[:, 0]) - value_caches.append(key_value_cache[:, 1]) - return key_caches, value_caches - - -def create_kv_caches_with_random( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: Union[str, torch.dtype] | None, - model_dtype: Union[str, torch.dtype] | None = None, - seed: int | None = None, - device: str | None = "cuda", -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - if cache_dtype == "fp8" and head_size % 16: - raise ValueError( - f"Does not support key cache of type fp8 with head_size {head_size}" - ) - from vllm.platforms import current_platform - - current_platform.seed_everything(seed) - - torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - - scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_caches: list[torch.Tensor] = [] - for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - key_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(key_cache, -scale, scale) - else: - raise ValueError(f"Does not support key cache of type {cache_dtype}") - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches: list[torch.Tensor] = [] - for _ in range(num_layers): - value_cache = torch.empty( - size=value_cache_shape, dtype=torch_dtype, device=device - ) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - value_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(value_cache, -scale, scale) - else: - raise ValueError(f"Does not support value cache of type {cache_dtype}") - value_caches.append(value_cache) - return key_caches, value_caches - - -@cache -def is_pin_memory_available() -> bool: - from vllm.platforms import current_platform - - return current_platform.is_pin_memory_available() - - -@cache -def is_uva_available() -> bool: - """Check if Unified Virtual Addressing (UVA) is available.""" - # UVA requires pinned memory. - # TODO: Add more requirements for UVA if needed. - return is_pin_memory_available() - - -class DeviceMemoryProfiler: - def __init__(self, device: torch.types.Device | None = None): - self.device = device - - def current_memory_usage(self) -> float: - # Return the memory usage in bytes. - from vllm.platforms import current_platform - - gc.collect() - return current_platform.get_current_memory_usage(self.device) - - def __enter__(self): - self.initial_memory = self.current_memory_usage() - # This allows us to call methods of the context manager if needed - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.final_memory = self.current_memory_usage() - self.consumed_memory = self.final_memory - self.initial_memory - - # Force garbage collection - gc.collect() - - -def make_ndarray_with_pad( - x: list[list[T]], - pad: T, - dtype: npt.DTypeLike, - *, - max_len: int | None = None, -) -> npt.NDArray: - """ - Make a padded array from 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. 
- """ - if max_len is None: - # Unlike for most functions, map is faster than a genexpr over `len` - max_len = max(map(len, x), default=0) - - padded_x = np.full((len(x), max_len), pad, dtype=dtype) - for ind, blocktb in enumerate(x): - assert len(blocktb) <= max_len - padded_x[ind, : len(blocktb)] = blocktb - - return padded_x - - -def make_tensor_with_pad( - x: list[list[T]], - pad: T, - dtype: torch.dtype, - *, - max_len: int | None = None, - device: Union[str, torch.device] | None = None, - pin_memory: bool = False, -) -> torch.Tensor: - """ - Make a padded tensor from 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. - """ - np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] - padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) - - tensor = torch.from_numpy(padded_x).to(device) - if pin_memory: - tensor = tensor.pin_memory() - - return tensor - - -def async_tensor_h2d( - data: list, - dtype: torch.dtype, - target_device: Union[str, torch.device], - pin_memory: bool, -) -> torch.Tensor: - """Asynchronously create a tensor and copy it from host to device.""" - t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") - return t.to(device=target_device, non_blocking=True) - - -def get_dtype_size(dtype: torch.dtype) -> int: - """Get the size of the data type in bytes.""" - return torch.tensor([], dtype=dtype).element_size() - - -# bool = 0, int = 1, float = 2, complex = 3 -def _get_precision_level(dtype: torch.dtype) -> int: - # NOTE: Complex dtypes return `is_floating_point=False` - return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 - - -def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): - """ - Test whether it is lossless to cast a tensor from - `src_dtype` to `tgt_dtype`. - """ - if src_dtype == tgt_dtype: - return True - - src_level = _get_precision_level(src_dtype) - tgt_level = _get_precision_level(tgt_dtype) - - if src_level < tgt_level: - return True - if src_level > tgt_level: - return False - - # Compare integral types - if not src_dtype.is_floating_point and not src_dtype.is_complex: - src_info = torch.iinfo(src_dtype) - tgt_info = torch.iinfo(tgt_dtype) - return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max - - # Compare floating-point types - src_info = torch.finfo(src_dtype) - tgt_info = torch.finfo(tgt_dtype) - return ( - src_info.min >= tgt_info.min - and src_info.max <= tgt_info.max - and src_info.resolution >= tgt_info.resolution - ) - - -def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): - """ - Get the common `dtype` where all of the other `dtypes` can be - cast to it without losing any information. 
- """ - return max( - dtypes, - key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), - ) - - -def as_list(maybe_list: Iterable[T]) -> list[T]: - """Convert iterable to list, unless it's already a list.""" - return maybe_list if isinstance(maybe_list, list) else list(maybe_list) - - -def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]: - if isinstance(obj, str) or not isinstance(obj, Iterable): - return [obj] # type: ignore[list-item] - return obj - - -# `collections` helpers -def is_list_of( - value: object, - typ: Union[type[T], tuple[type[T], ...]], - *, - check: Literal["first", "all"] = "first", -) -> TypeIs[list[T]]: - if not isinstance(value, list): - return False - - if check == "first": - return len(value) == 0 or isinstance(value[0], typ) - elif check == "all": - return all(isinstance(v, typ) for v in value) - - assert_never(check) - - -def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]: - """Flatten a list of lists to a single list.""" - return [item for sublist in lists for item in sublist] - - -def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): - """ - Unlike [`itertools.groupby`][], groups are not broken by - non-contiguous data. - """ - groups = defaultdict[_K, list[_V]](list) - - for value in values: - groups[key(value)].append(value) - - return groups.items() - - # TODO: This function can be removed if transformer_modules classes are # serialized by value when communicating between processes def init_cached_hf_modules() -> None: @@ -1099,146 +177,6 @@ def init_cached_hf_modules() -> None: init_hf_modules() -@cache -def find_library(lib_name: str) -> str: - """ - Find the library file in the system. - `lib_name` is full filename, with both prefix and suffix. - This function resolves `lib_name` to the full path of the library. - """ - # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa - # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard - # `/sbin/ldconfig` should exist in all Linux systems. - # `/sbin/ldconfig` searches the library in the system - libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() - # each line looks like the following: - # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 - locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] - # `LD_LIBRARY_PATH` searches the library in the user-defined paths - env_ld_library_path = envs.LD_LIBRARY_PATH - if not locs and env_ld_library_path: - locs = [ - os.path.join(dir, lib_name) - for dir in env_ld_library_path.split(":") - if os.path.exists(os.path.join(dir, lib_name)) - ] - if not locs: - raise ValueError(f"Cannot find {lib_name} in the system.") - return locs[0] - - -def find_nccl_library() -> str: - """ - We either use the library file specified by the `VLLM_NCCL_SO_PATH` - environment variable, or we find the library file brought by PyTorch. - After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be - found by `ctypes` automatically. 
- """ - so_file = envs.VLLM_NCCL_SO_PATH - - # manually load the nccl library - if so_file: - logger.info( - "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file - ) - else: - if torch.version.cuda is not None: - so_file = "libnccl.so.2" - elif torch.version.hip is not None: - so_file = "librccl.so.1" - else: - raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info("Found nccl from library %s", so_file) - return so_file - - -def find_nccl_include_paths() -> list[str] | None: - """ - We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH` - environment variable, or we find the library file brought by - nvidia-nccl-cuXX. load_inline by default uses - torch.utils.cpp_extension.include_paths - """ - paths: list[str] = [] - inc = envs.VLLM_NCCL_INCLUDE_PATH - if inc and os.path.isdir(inc): - paths.append(inc) - - try: - import importlib.util - - spec = importlib.util.find_spec("nvidia.nccl") - if spec and getattr(spec, "submodule_search_locations", None): - for loc in spec.submodule_search_locations: - inc_dir = os.path.join(loc, "include") - if os.path.exists(os.path.join(inc_dir, "nccl.h")): - paths.append(inc_dir) - except Exception: - pass - - seen = set() - out: list[str] = [] - for p in paths: - if p and p not in seen: - out.append(p) - seen.add(p) - return out or None - - -prev_set_stream = torch.cuda.set_stream - -_current_stream_tls = threading.local() - - -def _patched_set_stream(stream: torch.cuda.Stream) -> None: - _current_stream_tls.value = stream - prev_set_stream(stream) - - -torch.cuda.set_stream = _patched_set_stream - - -class _StreamPlaceholder: - def __init__(self): - self.synchronize = lambda: None - - -def current_stream() -> torch.cuda.Stream: - """ - replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. - it turns out that `torch.cuda.current_stream()` is quite expensive, - as it will construct a new stream object at each call. - here we patch `torch.cuda.set_stream` to keep track of the current stream - directly, so that we can avoid calling `torch.cuda.current_stream()`. - - the underlying hypothesis is that we do not call `torch._C._cuda_setStream` - from C/C++ code. - """ - from vllm.platforms import current_platform - - if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: - # when this function is called before any stream is set, - # we return the default stream. - # On ROCm using the default 0 stream in combination with RCCL - # is hurting performance. 
Therefore creating a dedicated stream - # per process - if current_platform.is_rocm(): - # torch.cuda.set_stream here is the alias of _pathed_set_stream - torch.cuda.set_stream(torch.cuda.Stream()) - elif current_platform.is_cpu(): - _current_stream_tls.value = _StreamPlaceholder() - else: - current_stream = current_platform.current_stream - if current_stream is not None: - _current_stream_tls.value = current_stream() - else: - raise ValueError( - "Fail to set current stream, current platform " - "may not support current_stream with torch API" - ) - return _current_stream_tls.value - - def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: """Set up function tracing for the current thread, if enabled via the VLLM_TRACE_FUNCTION environment variable @@ -1260,162 +198,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: enable_trace_function_call(log_path) -# `functools` helpers -def identity(value: T, **kwargs) -> T: - """Returns the first provided value.""" - return value - - -F = TypeVar("F", bound=Callable[..., Any]) - - -def deprecate_args( - start_index: int, - is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: str | None = None, -) -> Callable[[F], F]: - if not callable(is_deprecated): - is_deprecated = partial(identity, is_deprecated) - - def wrapper(fn: F) -> F: - params = inspect.signature(fn).parameters - pos_types = ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ) - pos_kws = [kw for kw, param in params.items() if param.kind in pos_types] - - @wraps(fn) - def inner(*args, **kwargs): - if is_deprecated(): - deprecated_args = pos_kws[start_index : len(args)] - if deprecated_args: - msg = ( - f"The positional arguments {deprecated_args} are " - "deprecated and will be removed in a future update." - ) - if additional_message is not None: - msg += f" {additional_message}" - - warnings.warn( - DeprecationWarning(msg), - stacklevel=3, # The inner function takes up one level - ) - - return fn(*args, **kwargs) - - return inner # type: ignore - - return wrapper - - -def deprecate_kwargs( - *kws: str, - is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: str | None = None, -) -> Callable[[F], F]: - deprecated_kws = set(kws) - - if not callable(is_deprecated): - is_deprecated = partial(identity, is_deprecated) - - def wrapper(fn: F) -> F: - @wraps(fn) - def inner(*args, **kwargs): - if is_deprecated(): - deprecated_kwargs = kwargs.keys() & deprecated_kws - if deprecated_kwargs: - msg = ( - f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update." - ) - if additional_message is not None: - msg += f" {additional_message}" - - warnings.warn( - DeprecationWarning(msg), - stacklevel=3, # The inner function takes up one level - ) - - return fn(*args, **kwargs) - - return inner # type: ignore - - return wrapper - - -@lru_cache(maxsize=8) -def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: - # Note: cuda_visible_devices is not used, but we keep it as an argument for - # LRU Cache purposes. 
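# Illustrative sketch (hypothetical decorator) of the warning mechanism in
# deprecate_kwargs above: intersect the call's kwargs with the deprecated
# names and emit a DeprecationWarning before delegating.
import functools
import warnings


def sketch_deprecate_kwargs(*names):
    deprecated = set(names)

    def wrapper(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            hit = kwargs.keys() & deprecated
            if hit:
                warnings.warn(
                    f"{sorted(hit)} are deprecated", DeprecationWarning, stacklevel=2
                )
            return fn(*args, **kwargs)

        return inner

    return wrapper


@sketch_deprecate_kwargs("old_flag")
def f(x, old_flag=None):
    return x


f(1, old_flag=True)  # emits a DeprecationWarning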
- - # Code below is based on - # https://github.com/pytorch/pytorch/blob/ - # c1cd946818442aca8c7f812b16d187ce1586c3bc/ - # torch/cuda/__init__.py#L831C1-L831C17 - import torch.cuda - import torch.version - - from vllm.platforms import current_platform - - if not torch.cuda._is_compiled(): - return 0 - if current_platform.is_rocm(): - # ROCm uses amdsmi instead of nvml for stateless device count - # This requires a sufficiently modern version of Torch 2.4.0 - raw_count = ( - torch.cuda._device_count_amdsmi() - if (hasattr(torch.cuda, "_device_count_amdsmi")) - else -1 - ) - else: - raw_count = torch.cuda._device_count_nvml() - r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count - return r - - -def cuda_device_count_stateless() -> int: - """Get number of CUDA devices, caching based on the value of - CUDA_VISIBLE_DEVICES at the time of call. - - This should be used instead of torch.cuda.device_count() - unless CUDA_VISIBLE_DEVICES has already been set to the desired - value.""" - - # This can be removed and simply replaced with torch.cuda.get_device_count - # after https://github.com/pytorch/pytorch/pull/122815 is released. - return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) - - -def cuda_is_initialized() -> bool: - """Check if CUDA is initialized.""" - if not torch.cuda._is_compiled(): - return False - return torch.cuda.is_initialized() - - -def xpu_is_initialized() -> bool: - """Check if XPU is initialized.""" - if not torch.xpu._is_compiled(): - return False - return torch.xpu.is_initialized() - - -def cuda_get_device_properties( - device, names: Sequence[str], init_cuda=False -) -> tuple[Any, ...]: - """Get specified CUDA device property values without initializing CUDA in - the current process.""" - if init_cuda or cuda_is_initialized(): - props = torch.cuda.get_device_properties(device) - return tuple(getattr(props, name) for name in names) - - # Run in subprocess to avoid initializing CUDA as a side effect. - mp_ctx = multiprocessing.get_context("fork") - with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor: - return executor.submit(cuda_get_device_properties, device, names, True).result() - - def weak_bind( bound_method: Callable[..., Any], ) -> Callable[..., None]: @@ -1432,21 +214,6 @@ def weak_bound(*args, **kwargs) -> None: return weak_bound -def run_once(f: Callable[P, None]) -> Callable[P, None]: - def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: - if wrapper.has_run: # type: ignore[attr-defined] - return - - with wrapper.lock: # type: ignore[attr-defined] - if not wrapper.has_run: # type: ignore[attr-defined] - wrapper.has_run = True # type: ignore[attr-defined] - return f(*args, **kwargs) - - wrapper.has_run = False # type: ignore[attr-defined] - wrapper.lock = threading.Lock() # type: ignore[attr-defined] - return wrapper - - class StoreBoolean(Action): def __call__(self, parser, namespace, values, option_string=None): if values.lower() == "true": @@ -1690,16 +457,16 @@ def repl(match: re.Match) -> str: elif arg.startswith("-O") and arg != "-O" and arg[2] != ".": # allow -O flag to be used without space, e.g. 
-O3 or -Odecode # -O.<...> handled later - # also handle -O=<level> here - level = arg[3:] if arg[2] == "=" else arg[2:] - processed_args.append(f"-O.level={level}") + # also handle -O=<mode> here + mode = arg[3:] if arg[2] == "=" else arg[2:] + processed_args.append(f"-O.mode={mode}") elif ( arg == "-O" and i + 1 < len(args) and args[i + 1] in {"0", "1", "2", "3"} ): - # Convert -O <n> to -O.level <n> - processed_args.append("-O.level") + # Convert -O <n> to -O.mode <n> + processed_args.append("-O.mode") else: processed_args.append(arg) @@ -1897,7 +664,7 @@ def load_config_file(self, file_path: str) -> list[str]: # only expecting a flat dictionary of atomic types processed_args: list[str] = [] - config: dict[str, Union[int, str]] = {} + config: dict[str, int | str] = {} try: with open(file_path) as config_file: config = yaml.safe_load(config_file) @@ -1929,149 +696,6 @@ def load_config_file(self, file_path: str) -> list[str]: return processed_args -async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): - """Utility function to run async task in a lock""" - async with lock: - return await task(*args, **kwargs) - - -@lru_cache -def supports_kw( - callable: Callable[..., object], - kw_name: str, - *, - requires_kw_only: bool = False, - allow_var_kwargs: bool = True, -) -> bool: - """Check if a keyword is a valid kwarg for a callable; if requires_kw_only - disallows kwargs names that can also be positional arguments. - """ - params = inspect.signature(callable).parameters - if not params: - return False - - param_val = params.get(kw_name) - - # Types where the it may be valid, i.e., explicitly defined & nonvariadic - passable_kw_types = set( - ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ) - ) - - if param_val: - is_sig_param = param_val.kind in passable_kw_types - # We want kwargs only, but this is passable as a positional arg - if ( - requires_kw_only - and is_sig_param - and param_val.kind != inspect.Parameter.KEYWORD_ONLY - ): - return False - if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or ( - not requires_kw_only and is_sig_param - ): - return True - - # If we're okay with var-kwargs, it's supported as long as - # the kw_name isn't something like *args, **kwargs - if allow_var_kwargs: - # Get the last param; type is ignored here because params is a proxy - # mapping, but it wraps an ordered dict, and they appear in order. - # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters - last_param = params[next(reversed(params))] # type: ignore - return ( - last_param.kind == inspect.Parameter.VAR_KEYWORD - and last_param.name != kw_name - ) - - return False - - -def get_allowed_kwarg_only_overrides( - callable: Callable[..., object], - overrides: Mapping[str, object] | None, - *, - requires_kw_only: bool = True, - allow_var_kwargs: bool = False, -) -> dict[str, Any]: - """ - Given a callable which has one or more keyword only params and a dict - mapping param names to values, drop values that can be not be kwarg - expanded to overwrite one or more keyword-only args. This is used in a - few places to handle custom processor overrides for multimodal models, - e.g., for profiling when processor options provided by the user - may affect the number of mm tokens per instance. - - Args: - callable: Callable which takes 0 or more keyword only arguments. - If None is provided, all overrides names are allowed. 
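# Illustrative sketch (simplified, not the exact parser) of the "-O"
# normalization in the hunk above: the compact forms -O3 / -O=3 / "-O 3" are
# all rewritten onto the dotted -O.mode key.
def sketch_normalize_opt(args: list[str]) -> list[str]:
    out = []
    for i, arg in enumerate(args):
        if arg.startswith("-O") and arg != "-O" and not arg.startswith("-O."):
            mode = arg[3:] if arg[2] == "=" else arg[2:]
            out.append(f"-O.mode={mode}")
        elif arg == "-O" and i + 1 < len(args) and args[i + 1] in {"0", "1", "2", "3"}:
            out.append("-O.mode")
        else:
            out.append(arg)
    return out


assert sketch_normalize_opt(["-O3"]) == ["-O.mode=3"]
assert sketch_normalize_opt(["-O=2"]) == ["-O.mode=2"]
assert sketch_normalize_opt(["-O", "1"]) == ["-O.mode", "1"]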
- overrides: Potential overrides to be used when invoking the callable. - allow_var_kwargs: Allows overrides that are expandable for var kwargs. - - Returns: - Dictionary containing the kwargs to be leveraged which may be used - to overwrite one or more keyword only arguments when invoking the - callable. - """ - if not overrides: - return {} - - # Drop any mm_processor_kwargs provided by the user that - # are not kwargs, unless it can fit it var_kwargs param - filtered_overrides = { - kwarg_name: val - for kwarg_name, val in overrides.items() - if supports_kw( - callable, - kwarg_name, - requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs, - ) - } - - # If anything is dropped, log a warning - dropped_keys = overrides.keys() - filtered_overrides.keys() - if dropped_keys: - if requires_kw_only: - logger.warning( - "The following intended overrides are not keyword-only args " - "and will be dropped: %s", - dropped_keys, - ) - else: - logger.warning( - "The following intended overrides are not keyword args " - "and will be dropped: %s", - dropped_keys, - ) - - return filtered_overrides - - -# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. -# In particular, the FakeScalarType is not supported for earlier versions of -# PyTorch which breaks dynamo for any ops registered using ScalarType. -def supports_dynamo() -> bool: - base_torch_version = Version(Version(torch.__version__).base_version) - return base_torch_version >= Version("2.4.0") - - -# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform -def supports_xccl() -> bool: - return ( - is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() - ) - - -# Some backends use pytorch version < 2.4.0 which doesn't -# support `torch.library.custom_op`. -def supports_custom_op() -> bool: - return hasattr(torch.library, "custom_op") - - class AtomicCounter: """An atomic, thread-safe counter""" @@ -2097,417 +721,6 @@ def value(self): return self._value -# Adapted from: https://stackoverflow.com/a/47212782/5082708 -class LazyDict(Mapping[str, T], Generic[T]): - def __init__(self, factory: dict[str, Callable[[], T]]): - self._factory = factory - self._dict: dict[str, T] = {} - - def __getitem__(self, key: str) -> T: - if key not in self._dict: - if key not in self._factory: - raise KeyError(key) - self._dict[key] = self._factory[key]() - return self._dict[key] - - def __setitem__(self, key: str, value: Callable[[], T]): - self._factory[key] = value - - def __iter__(self): - return iter(self._factory) - - def __len__(self): - return len(self._factory) - - -class ClassRegistry(UserDict[type[T], _V]): - def __getitem__(self, key: type[T]) -> _V: - for cls in key.mro(): - if cls in self.data: - return self.data[cls] - - raise KeyError(key) - - def __contains__(self, key: object) -> bool: - return self.contains(key) - - def contains(self, key: object, *, strict: bool = False) -> bool: - if not isinstance(key, type): - return False - - if strict: - return key in self.data - - return any(cls in self.data for cls in key.mro()) - - -def weak_ref_tensor(tensor: Any) -> Any: - """ - Create a weak reference to a tensor. - The new tensor will share the same data as the original tensor, - but will not keep the original tensor alive. 
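# Illustrative sketch of the MRO-based lookup that ClassRegistry implements
# above: a subclass resolves to the entry registered for its nearest ancestor.
class Base: ...


class Child(Base): ...


registry = {Base: "base-handler"}


def sketch_lookup(cls):
    for ancestor in cls.mro():
        if ancestor in registry:
            return registry[ancestor]
    raise KeyError(cls)


assert sketch_lookup(Child) == "base-handler"
assert sketch_lookup(Base) == "base-handler"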
- """ - if isinstance(tensor, torch.Tensor): - return torch.ops._C.weak_ref_tensor(tensor) - else: - return tensor - - -def weak_ref_tensors( - tensors: Union[ - torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], IntermediateTensors - ], -) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: - """ - Convenience function to create weak references to tensors, - for single tensor, list of tensors or tuple of tensors. - """ - if isinstance(tensors, torch.Tensor): - return weak_ref_tensor(tensors) - if isinstance(tensors, list): - return [weak_ref_tensor(t) for t in tensors] - if isinstance(tensors, tuple): - return tuple(weak_ref_tensor(t) for t in tensors) - - # For IntermediateTensors used in pipeline parallelism - from vllm.sequence import IntermediateTensors - - if isinstance(tensors, IntermediateTensors): - ret = IntermediateTensors( - {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} - ) - return ret - raise ValueError("Invalid type for tensors") - - -def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: - """ - Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA). - """ - assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" - return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) - - -def import_from_path(module_name: str, file_path: Union[str, os.PathLike]): - """ - Import a Python file according to its file path. - - Based on the official recipe: - https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly - """ - spec = importlib.util.spec_from_file_location(module_name, file_path) - if spec is None: - raise ModuleNotFoundError(f"No module named '{module_name}'") - - assert spec.loader is not None - - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - -@cache -def get_vllm_optional_dependencies(): - metadata = importlib.metadata.metadata("vllm") - requirements = metadata.get_all("Requires-Dist", []) - extras = metadata.get_all("Provides-Extra", []) - - return { - extra: [ - re.split(r";|>=|<=|==", req)[0] - for req in requirements - if req.endswith(f'extra == "{extra}"') - ] - for extra in extras - } - - -class _PlaceholderBase: - """ - Disallows downstream usage of placeholder modules. - - We need to explicitly override each dunder method because - [`__getattr__`][vllm.utils._PlaceholderBase.__getattr__] - is not called when they are accessed. - - Info: - [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup) - """ - - def __getattr__(self, key: str) -> Never: - """ - The main class should implement this to throw an error - for attribute accesses representing downstream usage. 
- """ - raise NotImplementedError - - # [Basic customization] - - def __lt__(self, other: object): - return self.__getattr__("__lt__") - - def __le__(self, other: object): - return self.__getattr__("__le__") - - def __eq__(self, other: object): - return self.__getattr__("__eq__") - - def __ne__(self, other: object): - return self.__getattr__("__ne__") - - def __gt__(self, other: object): - return self.__getattr__("__gt__") - - def __ge__(self, other: object): - return self.__getattr__("__ge__") - - def __hash__(self): - return self.__getattr__("__hash__") - - def __bool__(self): - return self.__getattr__("__bool__") - - # [Callable objects] - - def __call__(self, *args: object, **kwargs: object): - return self.__getattr__("__call__") - - # [Container types] - - def __len__(self): - return self.__getattr__("__len__") - - def __getitem__(self, key: object): - return self.__getattr__("__getitem__") - - def __setitem__(self, key: object, value: object): - return self.__getattr__("__setitem__") - - def __delitem__(self, key: object): - return self.__getattr__("__delitem__") - - # __missing__ is optional according to __getitem__ specification, - # so it is skipped - - # __iter__ and __reversed__ have a default implementation - # based on __len__ and __getitem__, so they are skipped. - - # [Numeric Types] - - def __add__(self, other: object): - return self.__getattr__("__add__") - - def __sub__(self, other: object): - return self.__getattr__("__sub__") - - def __mul__(self, other: object): - return self.__getattr__("__mul__") - - def __matmul__(self, other: object): - return self.__getattr__("__matmul__") - - def __truediv__(self, other: object): - return self.__getattr__("__truediv__") - - def __floordiv__(self, other: object): - return self.__getattr__("__floordiv__") - - def __mod__(self, other: object): - return self.__getattr__("__mod__") - - def __divmod__(self, other: object): - return self.__getattr__("__divmod__") - - def __pow__(self, other: object, modulo: object = ...): - return self.__getattr__("__pow__") - - def __lshift__(self, other: object): - return self.__getattr__("__lshift__") - - def __rshift__(self, other: object): - return self.__getattr__("__rshift__") - - def __and__(self, other: object): - return self.__getattr__("__and__") - - def __xor__(self, other: object): - return self.__getattr__("__xor__") - - def __or__(self, other: object): - return self.__getattr__("__or__") - - # r* and i* methods have lower priority than - # the methods for left operand so they are skipped - - def __neg__(self): - return self.__getattr__("__neg__") - - def __pos__(self): - return self.__getattr__("__pos__") - - def __abs__(self): - return self.__getattr__("__abs__") - - def __invert__(self): - return self.__getattr__("__invert__") - - # __complex__, __int__ and __float__ have a default implementation - # based on __index__, so they are skipped. - - def __index__(self): - return self.__getattr__("__index__") - - def __round__(self, ndigits: object = ...): - return self.__getattr__("__round__") - - def __trunc__(self): - return self.__getattr__("__trunc__") - - def __floor__(self): - return self.__getattr__("__floor__") - - def __ceil__(self): - return self.__getattr__("__ceil__") - - # [Context managers] - - def __enter__(self): - return self.__getattr__("__enter__") - - def __exit__(self, *args: object, **kwargs: object): - return self.__getattr__("__exit__") - - -class PlaceholderModule(_PlaceholderBase): - """ - A placeholder object to use when a module does not exist. 
- - This enables more informative errors when trying to access attributes - of a module that does not exist. - """ - - def __init__(self, name: str) -> None: - super().__init__() - - # Apply name mangling to avoid conflicting with module attributes - self.__name = name - - def placeholder_attr(self, attr_path: str): - return _PlaceholderModuleAttr(self, attr_path) - - def __getattr__(self, key: str): - name = self.__name - - try: - importlib.import_module(name) - except ImportError as exc: - for extra, names in get_vllm_optional_dependencies().items(): - if name in names: - msg = f"Please install vllm[{extra}] for {extra} support" - raise ImportError(msg) from exc - - raise exc - - raise AssertionError( - "PlaceholderModule should not be used " - "when the original module can be imported" - ) - - -class _PlaceholderModuleAttr(_PlaceholderBase): - def __init__(self, module: PlaceholderModule, attr_path: str) -> None: - super().__init__() - - # Apply name mangling to avoid conflicting with module attributes - self.__module = module - self.__attr_path = attr_path - - def placeholder_attr(self, attr_path: str): - return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}") - - def __getattr__(self, key: str): - getattr(self.__module, f"{self.__attr_path}.{key}") - - raise AssertionError( - "PlaceholderModule should not be used " - "when the original module can be imported" - ) - - -# create a library to hold the custom op -vllm_lib = Library("vllm", "FRAGMENT") # noqa - - -def direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: list[str] | None = None, - fake_impl: Callable | None = None, - target_lib: Library | None = None, - dispatch_key: str | None = None, - tags: tuple[torch.Tag, ...] = (), -): - """ - `torch.library.custom_op` can have significant overhead because it - needs to consider complicated dispatching logic. This function - directly registers a custom op and dispatches it to the CUDA backend. - See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 - for more details. - - By default, the custom op is registered to the vLLM library. If you - want to register it to a different library, you can pass the library - object to the `target_lib` argument. - - IMPORTANT: the lifetime of the operator is tied to the lifetime of the - library object. If you want to bind the operator to a different library, - make sure the library object is alive when the operator is used. - """ - if not supports_custom_op(): - from vllm.platforms import current_platform - - assert not current_platform.is_cuda_alike(), ( - "cuda platform needs torch>=2.4 to support custom op, " - "chances are you are using an old version of pytorch " - "or a custom build of pytorch. It is recommended to " - "use vLLM in a fresh new environment and let it install " - "the required dependencies." 
- ) - return - - if mutates_args is None: - mutates_args = [] - - if dispatch_key is None: - from vllm.platforms import current_platform - - dispatch_key = current_platform.dispatch_key - - import torch.library - - if hasattr(torch.library, "infer_schema"): - schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) - else: - # for pytorch 2.4 - import torch._custom_op.impl - - schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) - my_lib = target_lib or vllm_lib - my_lib.define(op_name + schema_str, tags=tags) - my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) - if fake_impl is not None: - my_lib._register_fake(op_name, fake_impl) - - -def resolve_obj_by_qualname(qualname: str) -> Any: - """ - Resolve an object by its fully-qualified class name. - """ - module_name, obj_name = qualname.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, obj_name) - - def kill_process_tree(pid: int): """ Kills all descendant processes of the given pid by sending SIGKILL. @@ -2533,183 +746,6 @@ def kill_process_tree(pid: int): os.kill(pid, signal.SIGKILL) -@dataclass -class MemorySnapshot: - """Memory snapshot.""" - - torch_peak: int = 0 - free_memory: int = 0 - total_memory: int = 0 - cuda_memory: int = 0 - torch_memory: int = 0 - non_torch_memory: int = 0 - timestamp: float = 0.0 - auto_measure: bool = True - - def __post_init__(self): - if self.auto_measure: - self.measure() - - def measure(self): - from vllm.platforms import current_platform - - # we measure the torch peak memory usage via allocated_bytes, - # rather than `torch.cuda.memory_reserved()` . - # After `torch.cuda.reset_peak_memory_stats()`, - # `torch.cuda.memory_reserved()` will keep growing, and only shrink - # when we call `torch.cuda.empty_cache()` or OOM happens. - self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) - - self.free_memory, self.total_memory = torch.cuda.mem_get_info() - shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark - if ( - current_platform.is_cuda() - and current_platform.get_device_capability() in shared_sysmem_device_mem_sms - ): - # On UMA (Orin, Thor and Spark) platform, - # where both CPU and GPU rely on system memory, - # the cudaMemGetInfo function shows the amount of free system memory - # rather than what’s actually available. - # In the case, - # torch.cuda.mem_get_info() only reports "free" memory, - # which can be lower than what is actually - # available due to not including cache memory. - # There’s also a comprehensive reference page - # that explains how you can compute the proper value yourself. - # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device - self.free_memory = psutil.virtual_memory().available - - self.cuda_memory = self.total_memory - self.free_memory - - # torch.cuda.memory_reserved() is how many bytes - # PyTorch gets from cuda (by calling cudaMalloc, etc.) 
- # this is used to measure the non-torch memory usage - self.torch_memory = torch.cuda.memory_reserved() - - self.non_torch_memory = self.cuda_memory - self.torch_memory - self.timestamp = time.time() - - def __sub__(self, other: MemorySnapshot) -> MemorySnapshot: - return MemorySnapshot( - torch_peak=self.torch_peak - other.torch_peak, - free_memory=self.free_memory - other.free_memory, - total_memory=self.total_memory - other.total_memory, - cuda_memory=self.cuda_memory - other.cuda_memory, - torch_memory=self.torch_memory - other.torch_memory, - non_torch_memory=self.non_torch_memory - other.non_torch_memory, - timestamp=self.timestamp - other.timestamp, - auto_measure=False, - ) - - -@dataclass -class MemoryProfilingResult: - """Memory profiling result. All numbers are in bytes.""" - - non_kv_cache_memory: int = 0 - torch_peak_increase: int = 0 - non_torch_increase: int = 0 - weights_memory: float = 0 - before_create: MemorySnapshot = field(default_factory=MemorySnapshot) - before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) - after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) - profile_time: float = 0.0 - - def __repr__(self) -> str: - return ( - f"Memory profiling takes {self.profile_time:.2f} seconds. " - f"Total non KV cache memory: " - f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " - f"torch peak memory increase: " - f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " - f"non-torch forward increase memory: " - f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " - f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB." - ) - - -@contextlib.contextmanager -def memory_profiling( - baseline_snapshot: MemorySnapshot, weights_memory: int -) -> Generator[MemoryProfilingResult, None, None]: - """Memory profiling context manager. - baseline_snapshot: the memory snapshot before the current vLLM instance. - weights_memory: memory used by PyTorch when loading the model weights. - Note that, before loading the model weights, we also initialize the device - and distributed environment, which may consume some memory. This part is not - included in the weights_memory because PyTorch does not control it. - - The memory in one GPU can be classified into 3 categories: - 1. memory used by anything other than the current vLLM instance. - 2. memory used by torch in the current vLLM instance. - 3. memory used in the current vLLM instance, but not by torch. - - A quantitive example: - - Before creating the current vLLM instance: - category 1: 1 GiB - category 2: 0 GiB - category 3: 0 GiB - - After creating the current vLLM instance and loading the model, - (i.e. before profiling): - category 1: 1 GiB - category 2: 2 GiB (model weights take 2 GiB) - category 3: 0.5 GiB (memory used by NCCL) - - During profiling (peak): - category 1: 1 GiB - category 2: 4 GiB (peak activation tensors take 2 GiB) - category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) - - After profiling: - category 1: 1 GiB - category 2: 3 GiB (after garbage-collecting activation tensors) - category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) - - In this case, non-kv cache takes 5 GiB in total, including: - a. 2 GiB used by the model weights (category 2) - b. 2 GiB reserved for the peak activation tensors (category 2) - c. 1 GiB used by non-torch components (category 3) - - The memory used for loading weights (a.) is directly given from the argument `weights_memory`. 
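# Illustrative arithmetic for the docstring's example above (all numbers in
# GiB): the non-KV-cache total is weights + peak activation increase +
# non-torch increase.
weights = 2.0
torch_peak_increase = 4.0 - 2.0  # peak during profiling minus before profiling
non_torch_increase = 1.0 - 0.0   # NCCL/backend buffers created since instance start
non_kv_cache = weights + torch_peak_increase + non_torch_increase
assert non_kv_cache == 5.0       # matches the "5 GiB in total" in the docstring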
- - The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). - - The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). - """ # noqa - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - - result = MemoryProfilingResult() - - result.before_create = baseline_snapshot - # the part of memory used for holding the model weights - result.weights_memory = weights_memory - - result.before_profile.measure() - - yield result - - gc.collect() - torch.cuda.empty_cache() - - result.after_profile.measure() - - diff_profile = result.after_profile - result.before_profile - diff_from_create = result.after_profile - result.before_create - result.torch_peak_increase = diff_profile.torch_peak - result.non_torch_increase = diff_from_create.non_torch_memory - result.profile_time = diff_profile.timestamp - - non_torch_memory = result.non_torch_increase - peak_activation_memory = result.torch_peak_increase - result.non_kv_cache_memory = ( - non_torch_memory + peak_activation_memory + result.weights_memory - ) # noqa - - # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 def set_ulimit(target_soft_limit=65535): if sys.platform.startswith("win"): @@ -2742,122 +778,6 @@ def get_exception_traceback(): return err_str -def split_zmq_path(path: str) -> tuple[str, str, str]: - """Split a zmq path into its parts.""" - parsed = urlparse(path) - if not parsed.scheme: - raise ValueError(f"Invalid zmq path: {path}") - - scheme = parsed.scheme - host = parsed.hostname or "" - port = str(parsed.port or "") - - if scheme == "tcp" and not all((host, port)): - # The host and port fields are required for tcp - raise ValueError(f"Invalid zmq path: {path}") - - if scheme != "tcp" and port: - # port only makes sense with tcp - raise ValueError(f"Invalid zmq path: {path}") - - return scheme, host, port - - -def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str: - """Make a ZMQ path from its parts. - - Args: - scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc). - host: The host - can be an IPv4 address, IPv6 address, or hostname. - port: Optional port number, only used for TCP sockets. - - Returns: - A properly formatted ZMQ path string. 
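# Illustrative sketch of the zmq path parsing above, using urlparse the same
# way split_zmq_path does (tcp requires host and port; other schemes must not
# carry a port).
from urllib.parse import urlparse


def sketch_split_zmq_path(path: str):
    parsed = urlparse(path)
    return parsed.scheme, parsed.hostname or "", str(parsed.port or "")


assert sketch_split_zmq_path("tcp://127.0.0.1:5555") == ("tcp", "127.0.0.1", "5555")
assert sketch_split_zmq_path("ipc:///tmp/engine.sock")[0] == "ipc"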
- """ - if port is None: - return f"{scheme}://{host}" - if is_valid_ipv6_address(host): - return f"{scheme}://[{host}]:{port}" - return f"{scheme}://{host}:{port}" - - -# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 -def make_zmq_socket( - ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] - path: str, - socket_type: Any, - bind: bool | None = None, - identity: bytes | None = None, - linger: int | None = None, -) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] - """Make a ZMQ socket with the proper bind/connect semantics.""" - - mem = psutil.virtual_memory() - socket = ctx.socket(socket_type) - - # Calculate buffer size based on system memory - total_mem = mem.total / 1024**3 - available_mem = mem.available / 1024**3 - # For systems with substantial memory (>32GB total, >16GB available): - # - Set a large 0.5GB buffer to improve throughput - # For systems with less memory: - # - Use system default (-1) to avoid excessive memory consumption - buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1 - - if bind is None: - bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) - - if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): - socket.setsockopt(zmq.RCVHWM, 0) - socket.setsockopt(zmq.RCVBUF, buf_size) - - if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): - socket.setsockopt(zmq.SNDHWM, 0) - socket.setsockopt(zmq.SNDBUF, buf_size) - - if identity is not None: - socket.setsockopt(zmq.IDENTITY, identity) - - if linger is not None: - socket.setsockopt(zmq.LINGER, linger) - - if socket_type == zmq.XPUB: - socket.setsockopt(zmq.XPUB_VERBOSE, True) - - # Determine if the path is a TCP socket with an IPv6 address. - # Enable IPv6 on the zmq socket if so. - scheme, host, _ = split_zmq_path(path) - if scheme == "tcp" and is_valid_ipv6_address(host): - socket.setsockopt(zmq.IPV6, 1) - - if bind: - socket.bind(path) - else: - socket.connect(path) - - return socket - - -@contextlib.contextmanager -def zmq_socket_ctx( - path: str, - socket_type: Any, - bind: bool | None = None, - linger: int = 0, - identity: bytes | None = None, -) -> Iterator[zmq.Socket]: - """Context manager for a ZMQ socket""" - - ctx = zmq.Context() # type: ignore[attr-defined] - try: - yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity) - except KeyboardInterrupt: - logger.debug("Got Keyboard Interrupt.") - - finally: - ctx.destroy(linger=linger) - - def _maybe_force_spawn(): """Check if we need to force the use of the `spawn` multiprocessing start method. @@ -2955,7 +875,7 @@ def bind_kv_cache( def run_method( obj: Any, - method: Union[str, bytes, Callable], + method: str | bytes | Callable, args: tuple[Any], kwargs: dict[str, Any], ) -> Any: @@ -3055,120 +975,6 @@ def wrapped_init(self, *args, **kwargs) -> None: return cls -class LazyLoader(types.ModuleType): - """ - LazyLoader module borrowed from Tensorflow - https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py - with an addition of "module caching". - - Lazily import a module, mainly to avoid pulling in large dependencies. 
- Modules such as `xgrammar` might do additional side effects, so we - only want to use this when it is needed, delaying all eager effects - """ - - def __init__( - self, - local_name: str, - parent_module_globals: dict[str, Any], - name: str, - ): - self._local_name = local_name - self._parent_module_globals = parent_module_globals - self._module: types.ModuleType | None = None - - super().__init__(str(name)) - - def _load(self) -> types.ModuleType: - # Import the target module and insert it into the parent's namespace - try: - module = importlib.import_module(self.__name__) - self._parent_module_globals[self._local_name] = module - # The additional add to sys.modules - # ensures library is actually loaded. - sys.modules[self._local_name] = module - except ModuleNotFoundError as err: - raise err from None - - # Update this object's dict so that if someone keeps a - # reference to the LazyLoader, lookups are efficient - # (__getattr__ is only called on lookups that fail). - self.__dict__.update(module.__dict__) - return module - - def __getattr__(self, item: Any) -> Any: - if self._module is None: - self._module = self._load() - return getattr(self._module, item) - - def __dir__(self) -> list[str]: - if self._module is None: - self._module = self._load() - return dir(self._module) - - -def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: - """ - Helper function to swap values for two keys - """ - v1 = obj.get(key1) - v2 = obj.get(key2) - if v1 is not None: - obj[key2] = v1 - else: - obj.pop(key2, None) - if v2 is not None: - obj[key1] = v2 - else: - obj.pop(key1, None) - - -@contextlib.contextmanager -def cprofile_context(save_file: str | None = None): - """Run a cprofile - - Args: - save_file: path to save the profile result. "1" or - None will result in printing to stdout. - """ - import cProfile - - prof = cProfile.Profile() - prof.enable() - - try: - yield - finally: - prof.disable() - if save_file and save_file != "1": - prof.dump_stats(save_file) - else: - prof.print_stats(sort="cumtime") - - -def cprofile(save_file: str | None = None, enabled: bool = True): - """Decorator to profile a Python method using cProfile. - - Args: - save_file: Path to save the profile result. - If "1", None, or "", results will be printed to stdout. - enabled: Set to false to turn this into a no-op - """ - - def decorator(func: Callable): - @wraps(func) - def wrapper(*args, **kwargs): - if not enabled: - # If profiling is disabled, just call the function directly. - return func(*args, **kwargs) - - with cprofile_context(save_file): - return func(*args, **kwargs) - - return wrapper - - return decorator - - # Only relevant for models using ALiBi (e.g, MPT) def check_use_alibi(model_config: ModelConfig) -> bool: cfg = model_config.hf_text_config @@ -3194,186 +1000,6 @@ def check_use_alibi(model_config: ModelConfig) -> bool: ) -def sha256(input: Any) -> bytes: - """Hash any picklable Python object using SHA-256. - - The input is serialized using pickle before hashing, which allows - arbitrary Python objects to be used. Note that this function does - not use a hash seed—if you need one, prepend it explicitly to the input. - - Args: - input: Any picklable Python object. - - Returns: - Bytes representing the SHA-256 hash of the serialized input. - """ - input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - return hashlib.sha256(input_bytes).digest() - - -def sha256_cbor(input: Any) -> bytes: - """ - Hash objects using CBOR serialization and SHA-256. 
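# Illustrative sketch of the hashing helpers above: serialize the object to
# bytes (pickle here), then take a SHA-256 digest of those bytes.
import hashlib
import pickle


def sketch_sha256(obj) -> bytes:
    data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return hashlib.sha256(data).digest()


assert sketch_sha256(("seed", 1, 2)) == sketch_sha256(("seed", 1, 2))
assert sketch_sha256([1, 2]) != sketch_sha256([2, 1])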
- - This option is useful for non-Python-dependent serialization and hashing. - - Args: - input: Object to be serialized and hashed. Supported types include - basic Python types and complex structures like lists, tuples, and - dictionaries. - Custom classes must implement CBOR serialization methods. - - Returns: - Bytes representing the SHA-256 hash of the CBOR serialized input. - """ - input_bytes = cbor2.dumps(input, canonical=True) - return hashlib.sha256(input_bytes).digest() - - -def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: - """Get a hash function by name, or raise an error if - the function is not found. - Args: - hash_fn_name: Name of the hash function. - Returns: - A hash function. - """ - if hash_fn_name == "sha256": - return sha256 - if hash_fn_name == "sha256_cbor": - return sha256_cbor - - raise ValueError(f"Unsupported hash function: {hash_fn_name}") - - -def is_torch_equal_or_newer(target: str) -> bool: - """Check if the installed torch version is >= the target version. - - Args: - target: a version string, like "2.6.0". - - Returns: - Whether the condition meets. - """ - try: - return _is_torch_equal_or_newer(str(torch.__version__), target) - except Exception: - # Fallback to PKG-INFO to load the package info, needed by the doc gen. - return Version(importlib.metadata.version("torch")) >= Version(target) - - -# Helper function used in testing. -def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: - torch_version = version.parse(torch_version) - return torch_version >= version.parse(target) - - -@cache -def _has_module(module_name: str) -> bool: - """Return True if *module_name* can be found in the current environment. - - The result is cached so that subsequent queries for the same module incur - no additional overhead. - """ - return importlib.util.find_spec(module_name) is not None - - -def has_pplx() -> bool: - """Whether the optional `pplx_kernels` package is available.""" - - return _has_module("pplx_kernels") - - -def has_deep_ep() -> bool: - """Whether the optional `deep_ep` package is available.""" - - return _has_module("deep_ep") - - -def has_deep_gemm() -> bool: - """Whether the optional `deep_gemm` package is available.""" - - return _has_module("deep_gemm") - - -def has_triton_kernels() -> bool: - """Whether the optional `triton_kernels` package is available.""" - - return _has_module("triton_kernels") - - -def has_tilelang() -> bool: - """Whether the optional `tilelang` package is available.""" - - return _has_module("tilelang") - - -def set_process_title( - name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX -) -> None: - """ - Set the current process title to a specific name with an - optional suffix. - - Args: - name: The title to assign to the current process. - suffix: An optional suffix to append to the base name. - prefix: A prefix to prepend to the front separated by `::`. 
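# Illustrative sketch of the version gate used by is_torch_equal_or_newer
# above, via packaging.version (dev/local suffixes compare as expected).
from packaging import version


def sketch_is_equal_or_newer(installed: str, target: str) -> bool:
    return version.parse(installed) >= version.parse(target)


assert sketch_is_equal_or_newer("2.8.0.dev20240901+cu121", "2.8.0.dev")
assert not sketch_is_equal_or_newer("2.3.1", "2.4.0")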
- """ - if suffix: - name = f"{name}_{suffix}" - setproctitle.setproctitle(f"{prefix}::{name}") - - -def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: - """Prepend each output line with process-specific prefix""" - - prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " - file_write = file.write - - def write_with_prefix(s: str): - if not s: - return - if file.start_new_line: # type: ignore[attr-defined] - file_write(prefix) - idx = 0 - while (next_idx := s.find("\n", idx)) != -1: - next_idx += 1 - file_write(s[idx:next_idx]) - if next_idx == len(s): - file.start_new_line = True # type: ignore[attr-defined] - return - file_write(prefix) - idx = next_idx - file_write(s[idx:]) - file.start_new_line = False # type: ignore[attr-defined] - - file.start_new_line = True # type: ignore[attr-defined] - file.write = write_with_prefix # type: ignore[method-assign] - - -def decorate_logs(process_name: str | None = None) -> None: - """ - Adds a process-specific prefix to each line of output written to stdout and - stderr. - - This function is intended to be called before initializing the api_server, - engine_core, or worker classes, so that all subsequent output from the - process is prefixed with the process name and PID. This helps distinguish - log output from different processes in multi-process environments. - - Args: - process_name: Optional; the name of the process to use in the prefix. - If not provided, the current process name from the multiprocessing - context is used. - """ - if process_name is None: - process_name = get_mp_context().current_process().name - pid = os.getpid() - _add_prefix(sys.stdout, process_name, pid) - _add_prefix(sys.stderr, process_name, pid) - - def length_from_prompt_token_ids_or_embeds( prompt_token_ids: list[int] | None, prompt_embeds: torch.Tensor | None, @@ -3396,36 +1022,3 @@ def length_from_prompt_token_ids_or_embeds( f" prompt_embeds={prompt_embeds_len}" ) return prompt_token_len - - -@contextlib.contextmanager -def set_env_var(key, value): - old = os.environ.get(key) - os.environ[key] = value - try: - yield - finally: - if old is None: - del os.environ[key] - else: - os.environ[key] = old - - -def unique_filepath(fn: Callable[[int], Path]) -> Path: - """ - unique_filepath returns a unique path by trying - to include an integer in increasing order. - - fn should be a callable that returns a path that - includes the passed int at a fixed location. - - Note: This function has a TOCTOU race condition. - Caller should use atomic operations (e.g., open with 'x' mode) - when creating the file to ensure thread safety. - """ - i = 0 - while True: - p = fn(i) - if not p.exists(): - return p - i += 1 diff --git a/vllm/utils/async_utils.py b/vllm/utils/async_utils.py new file mode 100644 index 000000000000..b6c24e1ceeee --- /dev/null +++ b/vllm/utils/async_utils.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Contains helpers related to asynchronous code. + +This is similar in concept to the `asyncio` module. 
+""" + +import asyncio +import contextlib +from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task +from collections.abc import AsyncGenerator, Awaitable, Callable +from concurrent.futures import Executor, ThreadPoolExecutor +from functools import partial +from typing import TypeVar + +from transformers.tokenization_utils_base import BatchEncoding +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + + +class AsyncMicrobatchTokenizer: + """Asynchronous tokenizer with micro-batching. + + Pulls pending encode/decode requests from a queue and batches them + up to reduce overhead. A single-thread ThreadPoolExecutor is used + so the event loop stays responsive. + """ + + def __init__( + self, + tokenizer, + max_batch_size: int = 32, + batch_wait_timeout_s: float = 0.002, + ) -> None: + self.tokenizer = tokenizer + self.max_batch_size = max_batch_size + self.batch_wait_timeout_s = batch_wait_timeout_s + + self._loop = asyncio.get_running_loop() + self._queues: dict[ + tuple, + asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]], + ] = {} + self._batcher_tasks: list[Task] = [] + + # Single-thread executor for blocking tokenizer calls. + self._executor = ThreadPoolExecutor(max_workers=1) + + # === Public async API === + async def __call__(self, prompt, **kwargs): + result_future: Future = self._loop.create_future() + key = self._queue_key("encode", kwargs) + queue = self._get_queue(self._loop, key) + await queue.put((prompt, kwargs, result_future)) + return await result_future + + async def decode(self, token_ids, **kwargs): + result_future: Future = self._loop.create_future() + key = self._queue_key("decode", kwargs) + queue = self._get_queue(self._loop, key) + await queue.put((token_ids, result_future)) + return await result_future + + # === Internal helpers === + def _get_queue( + self, loop: asyncio.AbstractEventLoop, key: tuple + ) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]: + """Get the request queue for the given operation key, creating a new + queue and batcher task if needed.""" + queue = self._queues.get(key) + if queue is None: + self._queues[key] = queue = asyncio.Queue() + if key[0] == "encode": + can_batch = key[1] != "other" + coro = self._batch_encode_loop(queue, can_batch) + else: + assert key[0] == "decode", f"Unknown operation type: {key[0]}." + coro = self._batch_decode_loop(queue) + self._batcher_tasks.append(loop.create_task(coro)) + return queue + + async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): + """Batch incoming encode requests for efficiency.""" + while True: + prompt, kwargs, result_future = await queue.get() + prompts = [prompt] + kwargs_list = [kwargs] + result_futures = [result_future] + deadline = self._loop.time() + self.batch_wait_timeout_s + + while len(prompts) < self.max_batch_size: + timeout = deadline - self._loop.time() + if timeout <= 0: + break + try: + prompt, kwargs, result_future = await asyncio.wait_for( + queue.get(), timeout + ) + prompts.append(prompt) + result_futures.append(result_future) + if not can_batch: + kwargs_list.append(kwargs) + except asyncio.TimeoutError: + break + + try: + # If every request uses identical kwargs we can run a single + # batched tokenizer call for a big speed-up. 
+ if can_batch and len(prompts) > 1: + batch_encode_fn = partial(self.tokenizer, prompts, **kwargs) + results = await self._loop.run_in_executor( + self._executor, batch_encode_fn + ) + + for i, fut in enumerate(result_futures): + if not fut.done(): + data = {k: v[i] for k, v in results.items()} + fut.set_result(BatchEncoding(data)) + else: + encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ + self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs) + ] + results = await self._loop.run_in_executor( + self._executor, encode_fn + ) + + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + async def _batch_decode_loop(self, queue: asyncio.Queue): + """Batch incoming decode requests for efficiency.""" + while True: + token_ids, result_future = await queue.get() + token_ids_list = [token_ids] + result_futures = [result_future] + deadline = self._loop.time() + self.batch_wait_timeout_s + + while len(token_ids_list) < self.max_batch_size: + timeout = deadline - self._loop.time() + if timeout <= 0: + break + try: + token_ids, result_future = await asyncio.wait_for( + queue.get(), timeout + ) + token_ids_list.append(token_ids) + result_futures.append(result_future) + except asyncio.TimeoutError: + break + + try: + # Perform a single batched decode call for all requests + results = await self._loop.run_in_executor( + self._executor, self.tokenizer.batch_decode, token_ids_list + ) + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + def _queue_key(self, op: str, kwargs: dict) -> tuple: + """ + Return a normalized key describing operation + kwargs. + + - `add_special_tokens`: {True/False} + - `truncation`: {True/False} + - If `truncation` is False (`max_length` is None), + returns a key for a can_batch queue. + - If `truncation` is True and `max_length` is None or equals + `tokenizer.model_max_length`, returns a key for a can_batch queue. + - Otherwise, returns a key for a cannot_batch queue. + + Examples: + - Decode: ("decode",) + - Encode typical: + ("encode", add_special_tokens, bool_truncation, max_length_label) + - Fallback: ("encode", "other") + """ + + if op == "decode": + return ("decode",) + + add_special_tokens = kwargs.get("add_special_tokens", True) + truncation = kwargs.get("truncation", False) + max_length = kwargs.get("max_length") + + if not truncation: + return "encode", add_special_tokens, False, None + + model_max = getattr(self.tokenizer, "model_max_length", None) + if max_length is None or (model_max is not None and max_length == model_max): + return "encode", add_special_tokens, True, "model_max" + + return "encode", "other" + + def __del__(self): + if ( + (tasks := getattr(self, "_batcher_tasks", None)) + and (loop := getattr(self, "_loop", None)) + and not loop.is_closed() + ): + + def cancel_tasks(): + for task in tasks: + task.cancel() + + loop.call_soon_threadsafe(cancel_tasks) + + +def cancel_task_threadsafe(task: Task): + if task and not task.done(): + run_in_loop(task.get_loop(), task.cancel) + + +def make_async( + func: Callable[P, T], + executor: Executor | None = None, +) -> Callable[P, Awaitable[T]]: + """ + Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. 
+ The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=executor, func=p_func) + + return _async_wrapper + + +def run_in_loop(loop: AbstractEventLoop, function: Callable, *args): + if in_loop(loop): + function(*args) + elif not loop.is_closed(): + loop.call_soon_threadsafe(function, *args) + + +def in_loop(event_loop: AbstractEventLoop) -> bool: + try: + return asyncio.get_running_loop() == event_loop + except RuntimeError: + return False + + +async def merge_async_iterators( + *iterators: AsyncGenerator[T, None], +) -> AsyncGenerator[tuple[int, T], None]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + if len(iterators) == 1: + # Fast-path single iterator case. + async for item in iterators[0]: + yield 0, item + return + + loop = asyncio.get_running_loop() + + awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)} + try: + while awaits: + done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) + for d in done: + pair = awaits.pop(d) + try: + item = await d + i, it = pair + awaits[loop.create_task(anext(it))] = pair + yield i, item + except StopAsyncIteration: + pass + finally: + # Cancel any remaining iterators + for f, (_, it) in awaits.items(): + with contextlib.suppress(BaseException): + f.cancel() + await it.aclose() + + +async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]: + """Collect all items from an async generator into a list.""" + items = [] + async for item in iterator: + items.append(item) + return items diff --git a/vllm/utils/cache.py b/vllm/utils/cache.py index a57ef9b70ccc..d5e08caa8a1e 100644 --- a/vllm/utils/cache.py +++ b/vllm/utils/cache.py @@ -1,11 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - from collections import UserDict -from collections.abc import Hashable, Iterator, KeysView, Mapping +from collections.abc import Callable, Hashable, Iterator, KeysView, Mapping from types import MappingProxyType -from typing import Callable, Generic, NamedTuple, TypeVar, Union, cast, overload +from typing import Generic, NamedTuple, TypeVar, cast, overload import cachetools @@ -43,7 +41,7 @@ def hit_ratio(self) -> float: return self.hits / self.total - def __sub__(self, other: CacheInfo): + def __sub__(self, other: "CacheInfo"): return CacheInfo( hits=self.hits - other.hits, total=self.total - other.total, @@ -129,12 +127,10 @@ def touch(self, key: _K) -> None: def get(self, key: _K, /) -> _V | None: ... @overload - def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ... + def get(self, key: _K, /, default: _V | _T) -> _V | _T: ... - def get( - self, key: _K, /, default: Union[_V, _T] | None = None - ) -> Union[_V, _T] | None: - value: Union[_V, _T] | None + def get(self, key: _K, /, default: _V | _T | None = None) -> _V | _T | None: + value: _V | _T | None if key in self: value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] @@ -149,12 +145,10 @@ def get( def pop(self, key: _K) -> _V: ... @overload - def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ... 
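# A minimal, self-contained sketch of the merge_async_iterators and
# collect_from_async_generator helpers added in async_utils.py above.
# The toy generators below are assumptions for illustration only.
import asyncio

from vllm.utils.async_utils import collect_from_async_generator, merge_async_iterators


async def numbers(prefix: str, n: int):
    for i in range(n):
        await asyncio.sleep(0)
        yield f"{prefix}{i}"


async def main() -> None:
    # Items are yielded as (iterator_index, item) as soon as any source is ready.
    async for idx, item in merge_async_iterators(numbers("a", 2), numbers("b", 3)):
        print(idx, item)

    # Drain a single async generator into a list.
    print(await collect_from_async_generator(numbers("c", 2)))


asyncio.run(main())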
+ def pop(self, key: _K, default: _V | _T) -> _V | _T: ... - def pop( - self, key: _K, default: Union[_V, _T] | None = None - ) -> Union[_V, _T] | None: - value: Union[_V, _T] | None + def pop(self, key: _K, default: _V | _T | None = None) -> _V | _T | None: + value: _V | _T | None if key not in self: return default diff --git a/vllm/utils/collection_utils.py b/vllm/utils/collection_utils.py new file mode 100644 index 000000000000..57271311828c --- /dev/null +++ b/vllm/utils/collection_utils.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Contains helpers that are applied to collections. + +This is similar in concept to the `collections` module. +""" + +from collections import UserDict, defaultdict +from collections.abc import Callable, Generator, Hashable, Iterable, Mapping +from typing import Generic, Literal, TypeVar + +from typing_extensions import TypeIs, assert_never + +T = TypeVar("T") +U = TypeVar("U") + +_K = TypeVar("_K", bound=Hashable) +_V = TypeVar("_V") + + +class ClassRegistry(UserDict[type[T], _V]): + """ + A registry that acts like a dictionary but searches for other classes + in the MRO if the original class is not found. + """ + + def __getitem__(self, key: type[T]) -> _V: + for cls in key.mro(): + if cls in self.data: + return self.data[cls] + + raise KeyError(key) + + def __contains__(self, key: object) -> bool: + return self.contains(key) + + def contains(self, key: object, *, strict: bool = False) -> bool: + if not isinstance(key, type): + return False + + if strict: + return key in self.data + + return any(cls in self.data for cls in key.mro()) + + +class LazyDict(Mapping[str, T], Generic[T]): + """ + Evaluates dictionary items only when they are accessed. 
+ + Adapted from: https://stackoverflow.com/a/47212782/5082708 + """ + + def __init__(self, factory: dict[str, Callable[[], T]]): + self._factory = factory + self._dict: dict[str, T] = {} + + def __getitem__(self, key: str) -> T: + if key not in self._dict: + if key not in self._factory: + raise KeyError(key) + self._dict[key] = self._factory[key]() + return self._dict[key] + + def __setitem__(self, key: str, value: Callable[[], T]): + self._factory[key] = value + + def __iter__(self): + return iter(self._factory) + + def __len__(self): + return len(self._factory) + + +def as_list(maybe_list: Iterable[T]) -> list[T]: + """Convert iterable to list, unless it's already a list.""" + return maybe_list if isinstance(maybe_list, list) else list(maybe_list) + + +def as_iter(obj: T | Iterable[T]) -> Iterable[T]: + if isinstance(obj, str) or not isinstance(obj, Iterable): + return [obj] # type: ignore[list-item] + return obj + + +def is_list_of( + value: object, + typ: type[T] | tuple[type[T], ...], + *, + check: Literal["first", "all"] = "first", +) -> TypeIs[list[T]]: + if not isinstance(value, list): + return False + + if check == "first": + return len(value) == 0 or isinstance(value[0], typ) + elif check == "all": + return all(isinstance(v, typ) for v in value) + + assert_never(check) + + +def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]: + """Yield successive chunk_size chunks from lst.""" + for i in range(0, len(lst), chunk_size): + yield lst[i : i + chunk_size] + + +def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]: + """Flatten a list of lists to a single list.""" + return [item for sublist in lists for item in sublist] + + +def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): + """ + Unlike [`itertools.groupby`][], groups are not broken by + non-contiguous data. + """ + groups = defaultdict[_K, list[_V]](list) + + for value in values: + groups[key(value)].append(value) + + return groups.items() + + +def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: + """Swap values between two keys.""" + v1 = obj.get(key1) + v2 = obj.get(key2) + if v1 is not None: + obj[key2] = v1 + else: + obj.pop(key2, None) + if v2 is not None: + obj[key1] = v2 + else: + obj.pop(key1, None) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 8f8f25f1302d..2e8cd302b0f5 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -5,24 +5,24 @@ Users of vLLM should always import **only** these wrappers. """ -from __future__ import annotations - import functools import importlib import os -from typing import Any, Callable, NoReturn +from collections.abc import Callable +from typing import Any, NoReturn import torch import vllm.envs as envs from vllm.logger import logger from vllm.platforms import current_platform -from vllm.utils import cdiv, has_deep_gemm +from vllm.utils import cdiv +from vllm.utils.import_utils import has_deep_gemm @functools.cache def is_deep_gemm_supported() -> bool: - """Return ``True`` if DeepGEMM is supported on the current platform. + """Return `True` if DeepGEMM is supported on the current platform. Currently, only Hopper and Blackwell GPUs are supported. """ is_supported_arch = current_platform.is_cuda() and ( @@ -34,7 +34,7 @@ def is_deep_gemm_supported() -> bool: @functools.cache def is_deep_gemm_e8m0_used() -> bool: - """Return ``True`` if vLLM is configured to use DeepGEMM " + """Return `True` if vLLM is configured to use DeepGEMM " "E8M0 scale on a Hopper or Blackwell-class GPU. 
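# A small sketch of the ClassRegistry and full_groupby helpers from
# collection_utils.py above; the toy classes and values are assumptions.
from vllm.utils.collection_utils import ClassRegistry, full_groupby


class Base:
    pass


class Child(Base):
    pass


registry = ClassRegistry()
registry[Base] = "base-handler"

# Lookups walk the MRO, so a subclass falls back to its parent's entry.
assert registry[Child] == "base-handler"
assert registry.contains(Child) and not registry.contains(Child, strict=True)

# Unlike itertools.groupby, full_groupby does not require pre-sorted input.
groups = dict(full_groupby(["apple", "bee", "ant", "axe"], key=lambda s: s[0]))
assert groups["a"] == ["apple", "ant", "axe"]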
""" if not is_deep_gemm_supported(): @@ -76,6 +76,7 @@ def _missing(*_: Any, **__: Any) -> NoReturn: _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None +_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None def _lazy_init() -> None: @@ -84,7 +85,7 @@ def _lazy_init() -> None: global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl global _get_paged_mqa_logits_metadata_impl global _get_mn_major_tma_aligned_tensor_impl - + global _get_mk_alignment_for_contiguous_layout_impl # fast path if ( _fp8_gemm_nt_impl is not None @@ -93,6 +94,7 @@ def _lazy_init() -> None: or _fp8_mqa_logits_impl is not None or _fp8_paged_mqa_logits_impl is not None or _get_paged_mqa_logits_metadata_impl is not None + or _get_mk_alignment_for_contiguous_layout_impl is not None ): return @@ -119,6 +121,9 @@ def _lazy_init() -> None: _get_mn_major_tma_aligned_tensor_impl = getattr( _dg, "get_mn_major_tma_aligned_tensor", None ) + _get_mk_alignment_for_contiguous_layout_impl = getattr( + _dg, "get_mk_alignment_for_contiguous_layout", None + ) def get_num_sms() -> int: @@ -127,6 +132,15 @@ def get_num_sms() -> int: return int(_dg.get_num_sms()) +@functools.cache +def get_mk_alignment_for_contiguous_layout() -> list[int]: + _lazy_init() + if _get_mk_alignment_for_contiguous_layout_impl is None: + return _missing() + mk_align_size = _get_mk_alignment_for_contiguous_layout_impl() + return [mk_align_size, mk_align_size] + + def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor""" _lazy_init() @@ -298,9 +312,9 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): """Return a global difference metric for unit tests. DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element - error, causing ``torch.testing.assert_close`` to fail. Instead of checking + error, causing `torch.testing.assert_close` to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor - and report ``1 - sim``. Once kernel accuracy improves this helper can be + and report `1 - sim`. Once kernel accuracy improves this helper can be removed. """ @@ -339,4 +353,5 @@ def should_use_deepgemm_for_fp8_linear( "get_num_sms", "should_use_deepgemm_for_fp8_linear", "get_col_major_tma_aligned_tensor", + "get_mk_alignment_for_contiguous_layout", ] diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 159d19bfad31..d7e4ea2e0388 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -5,14 +5,14 @@ Users of vLLM should always import **only** these wrappers. 
""" -from __future__ import annotations - import contextlib import functools import importlib import importlib.util import os -from typing import Any, Callable, NoReturn +import shutil +from collections.abc import Callable +from typing import Any, NoReturn import requests import torch @@ -34,10 +34,17 @@ @functools.cache def has_flashinfer() -> bool: - """Return ``True`` if FlashInfer is available.""" + """Return `True` if FlashInfer is available.""" # Use find_spec to check if the module exists without importing it # This avoids potential CUDA initialization side effects - return importlib.util.find_spec("flashinfer") is not None + if importlib.util.find_spec("flashinfer") is None: + logger.debug_once("FlashInfer unavailable since package was not found") + return False + # Also check if nvcc is available since it's required to JIT compile flashinfer + if shutil.which("nvcc") is None: + logger.debug_once("FlashInfer unavailable since nvcc was not found") + return False + return True def _missing(*_: Any, **__: Any) -> NoReturn: @@ -107,13 +114,13 @@ def wrapper(*args, **kwargs): @functools.cache def has_flashinfer_comm() -> bool: - """Return ``True`` if FlashInfer comm module is available.""" + """Return `True` if FlashInfer comm module is available.""" return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None @functools.cache def has_flashinfer_all2all() -> bool: - """Return ``True`` if FlashInfer mnnvl all2all is available.""" + """Return `True` if FlashInfer mnnvl all2all is available.""" if not has_flashinfer_comm(): return False @@ -134,7 +141,7 @@ def has_flashinfer_all2all() -> bool: @functools.cache def has_flashinfer_moe() -> bool: - """Return ``True`` if FlashInfer MoE module is available.""" + """Return `True` if FlashInfer MoE module is available.""" return ( has_flashinfer() and importlib.util.find_spec("flashinfer.fused_moe") is not None @@ -143,7 +150,7 @@ def has_flashinfer_moe() -> bool: @functools.cache def has_flashinfer_cutlass_fused_moe() -> bool: - """Return ``True`` if FlashInfer CUTLASS fused MoE is available.""" + """Return `True` if FlashInfer CUTLASS fused MoE is available.""" if not has_flashinfer_moe(): return False @@ -164,7 +171,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool: @functools.cache def has_nvidia_artifactory() -> bool: - """Return ``True`` if NVIDIA's artifactory is accessible. + """Return `True` if NVIDIA's artifactory is accessible. This checks connectivity to the kernel inference library artifactory which is required for downloading certain cubin kernels like TRTLLM FHMA. @@ -211,9 +218,9 @@ def _force_use_trtllm_attention(env_value: bool | None) -> bool | None: def force_use_trtllm_attention() -> bool | None: """ - Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set, - return ``True`` if TRTLLM attention is forced to be used, - return ``False`` if TRTLLM attention is forced to be not used. + Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set, + return `True` if TRTLLM attention is forced to be used, + return `False` if TRTLLM attention is forced to be not used. 
""" return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) @@ -237,7 +244,7 @@ def use_trtllm_attention( has_sinks: bool = False, has_spec: bool = False, ) -> bool: - """Return ``True`` if TRTLLM attention is used.""" + """Return `True` if TRTLLM attention is used.""" force_use_trtllm = force_use_trtllm_attention() # Environment variable is set to 0 - respect it @@ -269,11 +276,6 @@ def use_trtllm_attention( # Must use TRTLLM attention if query is FP8 quantized if q_dtype == current_platform.fp8_dtype(): - if has_sinks: - raise RuntimeError( - "TRTLLM FP8-qkv kernel is not supported for attention sinks. " - "Use kv_cache_dtype=auto for now." - ) logger.info_once("Using TRTLLM attention (query is quantized).") return True @@ -386,8 +388,6 @@ def flashinfer_scaled_fp4_mm( assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert a.stride(-1) == 1 and b.stride(-1) == 1 assert a.shape[1] == b.shape[1] - assert block_scale_a.shape[1] == a.shape[1] // 8 - assert block_scale_b.shape[1] == b.shape[1] // 8 if backend == "cutlass": block_scale_a = block_scale_a.view(torch.uint8) diff --git a/vllm/utils/func_utils.py b/vllm/utils/func_utils.py new file mode 100644 index 000000000000..c061a0dad552 --- /dev/null +++ b/vllm/utils/func_utils.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Contains helpers that are applied to functions. + +This is similar in concept to the `functools` module. +""" + +import inspect +import threading +import warnings +from collections.abc import Callable, Mapping +from functools import lru_cache, partial, wraps +from typing import Any, TypeVar + +from typing_extensions import ParamSpec + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +P = ParamSpec("P") +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Any]) + + +def identity(value: T, **kwargs) -> T: + """Returns the first provided value.""" + return value + + +def run_once(f: Callable[P, None]) -> Callable[P, None]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: + if wrapper.has_run: # type: ignore[attr-defined] + return + + with wrapper.lock: # type: ignore[attr-defined] + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) + + wrapper.has_run = False # type: ignore[attr-defined] + wrapper.lock = threading.Lock() # type: ignore[attr-defined] + return wrapper + + +def deprecate_args( + start_index: int, + is_deprecated: bool | Callable[[], bool] = True, + additional_message: str | None = None, +) -> Callable[[F], F]: + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + params = inspect.signature(fn).parameters + pos_types = ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + pos_kws = [kw for kw, param in params.items() if param.kind in pos_types] + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_args = pos_kws[start_index : len(args)] + if deprecated_args: + msg = ( + f"The positional arguments {deprecated_args} are " + "deprecated and will be removed in a future update." 
+ ) + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +def deprecate_kwargs( + *kws: str, + is_deprecated: bool | Callable[[], bool] = True, + additional_message: str | None = None, +) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update." + ) + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +@lru_cache +def supports_kw( + callable: Callable[..., object], + kw_name: str, + *, + requires_kw_only: bool = False, + allow_var_kwargs: bool = True, +) -> bool: + """Check if a keyword is a valid kwarg for a callable; if requires_kw_only + disallows kwargs names that can also be positional arguments. + """ + params = inspect.signature(callable).parameters + if not params: + return False + + param_val = params.get(kw_name) + + # Types where the it may be valid, i.e., explicitly defined & nonvariadic + passable_kw_types = set( + ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ) + + if param_val: + is_sig_param = param_val.kind in passable_kw_types + # We want kwargs only, but this is passable as a positional arg + if ( + requires_kw_only + and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY + ): + return False + if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or ( + not requires_kw_only and is_sig_param + ): + return True + + # If we're okay with var-kwargs, it's supported as long as + # the kw_name isn't something like *args, **kwargs + if allow_var_kwargs: + # Get the last param; type is ignored here because params is a proxy + # mapping, but it wraps an ordered dict, and they appear in order. + # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters + last_param = params[next(reversed(params))] # type: ignore + return ( + last_param.kind == inspect.Parameter.VAR_KEYWORD + and last_param.name != kw_name + ) + + return False + + +def get_allowed_kwarg_only_overrides( + callable: Callable[..., object], + overrides: Mapping[str, object] | None, + *, + requires_kw_only: bool = True, + allow_var_kwargs: bool = False, +) -> dict[str, Any]: + """ + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. + + Args: + callable: Callable which takes 0 or more keyword only arguments. + If None is provided, all overrides names are allowed. + overrides: Potential overrides to be used when invoking the callable. 
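# A short sketch of the deprecate_kwargs and supports_kw helpers defined in
# func_utils.py above; the configure() function is an assumed example.
from vllm.utils.func_utils import deprecate_kwargs, supports_kw


@deprecate_kwargs("legacy_flag", additional_message="Use new_flag instead.")
def configure(new_flag: bool = False, legacy_flag: bool | None = None) -> bool:
    return new_flag if legacy_flag is None else legacy_flag


# Using the deprecated kwarg emits a DeprecationWarning (subject to warning
# filters) and then calls through to the original function.
configure(legacy_flag=True)

# supports_kw inspects the signature; requires_kw_only rejects parameters
# that could also be passed positionally.
assert supports_kw(configure, "new_flag")
assert not supports_kw(configure, "new_flag", requires_kw_only=True)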
+ allow_var_kwargs: Allows overrides that are expandable for var kwargs. + + Returns: + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. + """ + if not overrides: + return {} + + # Drop any mm_processor_kwargs provided by the user that + # are not kwargs, unless it can fit it var_kwargs param + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if supports_kw( + callable, + kwarg_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) + } + + # If anything is dropped, log a warning + dropped_keys = overrides.keys() - filtered_overrides.keys() + if dropped_keys: + if requires_kw_only: + logger.warning( + "The following intended overrides are not keyword-only args " + "and will be dropped: %s", + dropped_keys, + ) + else: + logger.warning( + "The following intended overrides are not keyword args " + "and will be dropped: %s", + dropped_keys, + ) + + return filtered_overrides diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py index e3b5b61dd364..4dd85ef26f34 100644 --- a/vllm/utils/gc_utils.py +++ b/vllm/utils/gc_utils.py @@ -5,9 +5,9 @@ import time from collections import Counter from contextlib import suppress -from typing import Any, Optional +from typing import Any -from vllm.envs import VLLM_GC_DEBUG +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -21,7 +21,7 @@ class GCDebugConfig: - '{"top_objects":5}': enable GC debugger with top 5 collected objects """ - def __init__(self, gc_debug_conf: Optional[str] = None) -> None: + def __init__(self, gc_debug_conf: str | None = None) -> None: self.enabled: bool = False self.top_objects: int = -1 @@ -36,8 +36,8 @@ def __init__(self, gc_debug_conf: Optional[str] = None) -> None: self.top_objects = json_conf.get("top_objects", -1) except Exception: self.enabled = False - logger.error("Failed to parse VLLM_GC_DEBUG(%s)", VLLM_GC_DEBUG) - logger.info("GC Debug Config. %s", str(self)) + logger.error("Failed to parse VLLM_GC_DEBUG(%s)", envs.VLLM_GC_DEBUG) + logger.debug("GC Debug Config. %s", str(self)) def __repr__(self) -> str: return f"enabled:{self.enabled},top_objects:{self.top_objects}" @@ -93,7 +93,7 @@ def maybe_attach_gc_debug_callback() -> None: """ Attached a callback for GC debug when VLLM_GC_DEBUG is enabled. """ - config = GCDebugConfig(VLLM_GC_DEBUG) + config = GCDebugConfig(envs.VLLM_GC_DEBUG) if config.enabled: debugger: GCDebugger = GCDebugger(config) diff --git a/vllm/utils/hashing.py b/vllm/utils/hashing.py new file mode 100644 index 000000000000..49f4f13d115f --- /dev/null +++ b/vllm/utils/hashing.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import hashlib +import pickle +from collections.abc import Callable +from typing import Any + +import cbor2 + + +def sha256(input: Any) -> bytes: + """Hash any picklable Python object using SHA-256. + + The input is serialized using pickle before hashing, which allows + arbitrary Python objects to be used. Note that this function does + not use a hash seed—if you need one, prepend it explicitly to the input. + + Args: + input: Any picklable Python object. + + Returns: + Bytes representing the SHA-256 hash of the serialized input. 
+ """ + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + return hashlib.sha256(input_bytes).digest() + + +def sha256_cbor(input: Any) -> bytes: + """Hash objects using CBOR serialization and SHA-256. + + This option is useful for non-Python-dependent serialization and hashing. + + Args: + input: Object to be serialized and hashed. Supported types include + basic Python types and complex structures like lists, tuples, and + dictionaries. + Custom classes must implement CBOR serialization methods. + + Returns: + Bytes representing the SHA-256 hash of the CBOR serialized input. + """ + input_bytes = cbor2.dumps(input, canonical=True) + return hashlib.sha256(input_bytes).digest() + + +def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: + """Get a hash function by name, or raise an error if the function is not found. + + Args: + hash_fn_name: Name of the hash function. + + Returns: + A hash function. + """ + if hash_fn_name == "sha256": + return sha256 + if hash_fn_name == "sha256_cbor": + return sha256_cbor + + raise ValueError(f"Unsupported hash function: {hash_fn_name}") diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py new file mode 100644 index 000000000000..65f588b52e5e --- /dev/null +++ b/vllm/utils/import_utils.py @@ -0,0 +1,362 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Contains helpers related to importing modules. + +This is similar in concept to the `importlib` module. +""" + +import importlib.metadata +import importlib.util +import os +import sys +from functools import cache +from types import ModuleType +from typing import Any + +import regex as re +from typing_extensions import Never + + +def import_from_path(module_name: str, file_path: str | os.PathLike): + """ + Import a Python file according to its file path. + + Based on the official recipe: + https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly + """ + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ModuleNotFoundError(f"No module named {module_name!r}") + + assert spec.loader is not None + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def resolve_obj_by_qualname(qualname: str) -> Any: + """ + Resolve an object by its fully-qualified class name. + """ + module_name, obj_name = qualname.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, obj_name) + + +@cache +def get_vllm_optional_dependencies(): + metadata = importlib.metadata.metadata("vllm") + requirements = metadata.get_all("Requires-Dist", []) + extras = metadata.get_all("Provides-Extra", []) + + return { + extra: [ + re.split(r";|>=|<=|==", req)[0] + for req in requirements + if req.endswith(f'extra == "{extra}"') + ] + for extra in extras + } + + +class _PlaceholderBase: + """ + Disallows downstream usage of placeholder modules. + + We need to explicitly override each dunder method because + [`__getattr__`][vllm.utils.import_utils._PlaceholderBase.__getattr__] + is not called when they are accessed. 
+ + Info: + [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup) + """ + + def __getattr__(self, key: str) -> Never: + """ + The main class should implement this to throw an error + for attribute accesses representing downstream usage. + """ + raise NotImplementedError + + # [Basic customization] + + def __lt__(self, other: object): + return self.__getattr__("__lt__") + + def __le__(self, other: object): + return self.__getattr__("__le__") + + def __eq__(self, other: object): + return self.__getattr__("__eq__") + + def __ne__(self, other: object): + return self.__getattr__("__ne__") + + def __gt__(self, other: object): + return self.__getattr__("__gt__") + + def __ge__(self, other: object): + return self.__getattr__("__ge__") + + def __hash__(self): + return self.__getattr__("__hash__") + + def __bool__(self): + return self.__getattr__("__bool__") + + # [Callable objects] + + def __call__(self, *args: object, **kwargs: object): + return self.__getattr__("__call__") + + # [Container types] + + def __len__(self): + return self.__getattr__("__len__") + + def __getitem__(self, key: object): + return self.__getattr__("__getitem__") + + def __setitem__(self, key: object, value: object): + return self.__getattr__("__setitem__") + + def __delitem__(self, key: object): + return self.__getattr__("__delitem__") + + # __missing__ is optional according to __getitem__ specification, + # so it is skipped + + # __iter__ and __reversed__ have a default implementation + # based on __len__ and __getitem__, so they are skipped. + + # [Numeric Types] + + def __add__(self, other: object): + return self.__getattr__("__add__") + + def __sub__(self, other: object): + return self.__getattr__("__sub__") + + def __mul__(self, other: object): + return self.__getattr__("__mul__") + + def __matmul__(self, other: object): + return self.__getattr__("__matmul__") + + def __truediv__(self, other: object): + return self.__getattr__("__truediv__") + + def __floordiv__(self, other: object): + return self.__getattr__("__floordiv__") + + def __mod__(self, other: object): + return self.__getattr__("__mod__") + + def __divmod__(self, other: object): + return self.__getattr__("__divmod__") + + def __pow__(self, other: object, modulo: object = ...): + return self.__getattr__("__pow__") + + def __lshift__(self, other: object): + return self.__getattr__("__lshift__") + + def __rshift__(self, other: object): + return self.__getattr__("__rshift__") + + def __and__(self, other: object): + return self.__getattr__("__and__") + + def __xor__(self, other: object): + return self.__getattr__("__xor__") + + def __or__(self, other: object): + return self.__getattr__("__or__") + + # r* and i* methods have lower priority than + # the methods for left operand so they are skipped + + def __neg__(self): + return self.__getattr__("__neg__") + + def __pos__(self): + return self.__getattr__("__pos__") + + def __abs__(self): + return self.__getattr__("__abs__") + + def __invert__(self): + return self.__getattr__("__invert__") + + # __complex__, __int__ and __float__ have a default implementation + # based on __index__, so they are skipped. 
+ + def __index__(self): + return self.__getattr__("__index__") + + def __round__(self, ndigits: object = ...): + return self.__getattr__("__round__") + + def __trunc__(self): + return self.__getattr__("__trunc__") + + def __floor__(self): + return self.__getattr__("__floor__") + + def __ceil__(self): + return self.__getattr__("__ceil__") + + # [Context managers] + + def __enter__(self): + return self.__getattr__("__enter__") + + def __exit__(self, *args: object, **kwargs: object): + return self.__getattr__("__exit__") + + +class PlaceholderModule(_PlaceholderBase): + """ + A placeholder object to use when a module does not exist. + + This enables more informative errors when trying to access attributes + of a module that does not exist. + """ + + def __init__(self, name: str) -> None: + super().__init__() + + # Apply name mangling to avoid conflicting with module attributes + self.__name = name + + def placeholder_attr(self, attr_path: str): + return _PlaceholderModuleAttr(self, attr_path) + + def __getattr__(self, key: str) -> Never: + name = self.__name + + try: + importlib.import_module(name) + except ImportError as exc: + for extra, names in get_vllm_optional_dependencies().items(): + if name in names: + msg = f"Please install vllm[{extra}] for {extra} support" + raise ImportError(msg) from exc + + raise exc + + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) + + +class _PlaceholderModuleAttr(_PlaceholderBase): + def __init__(self, module: PlaceholderModule, attr_path: str) -> None: + super().__init__() + + # Apply name mangling to avoid conflicting with module attributes + self.__module = module + self.__attr_path = attr_path + + def placeholder_attr(self, attr_path: str): + return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}") + + def __getattr__(self, key: str) -> Never: + getattr(self.__module, f"{self.__attr_path}.{key}") + + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) + + +class LazyLoader(ModuleType): + """ + `LazyLoader` module borrowed from [Tensorflow] + (https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py) + with an addition of "module caching". + + Lazily import a module, mainly to avoid pulling in large dependencies. + Modules such as `xgrammar` might do additional side effects, so we + only want to use this when it is needed, delaying all eager effects. + """ + + def __init__( + self, + local_name: str, + parent_module_globals: dict[str, Any], + name: str, + ): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + self._module: ModuleType | None = None + + super().__init__(str(name)) + + def _load(self) -> ModuleType: + # Import the target module and insert it into the parent's namespace + try: + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # The additional add to sys.modules + # ensures library is actually loaded. + sys.modules[self._local_name] = module + except ModuleNotFoundError as err: + raise err from None + + # Update this object's dict so that if someone keeps a + # reference to the LazyLoader, lookups are efficient + # (__getattr__ is only called on lookups that fail). 
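# A sketch of how PlaceholderModule (defined above) can stand in for an
# optional dependency; "some_optional_pkg" is a hypothetical package name.
from vllm.utils.import_utils import PlaceholderModule

try:
    import some_optional_pkg  # hypothetical optional dependency
except ImportError:
    some_optional_pkg = PlaceholderModule("some_optional_pkg")

# The placeholder only raises (with an informative ImportError) once an
# attribute is actually accessed, not at import time.
try:
    some_optional_pkg.do_something()
except ImportError:
    print("some_optional_pkg is not installed")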
+ self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item: Any) -> Any: + if self._module is None: + self._module = self._load() + return getattr(self._module, item) + + def __dir__(self) -> list[str]: + if self._module is None: + self._module = self._load() + return dir(self._module) + + +# Optional dependency detection utilities +@cache +def _has_module(module_name: str) -> bool: + """Return True if *module_name* can be found in the current environment. + + The result is cached so that subsequent queries for the same module incur + no additional overhead. + """ + return importlib.util.find_spec(module_name) is not None + + +def has_pplx() -> bool: + """Whether the optional `pplx_kernels` package is available.""" + return _has_module("pplx_kernels") + + +def has_deep_ep() -> bool: + """Whether the optional `deep_ep` package is available.""" + return _has_module("deep_ep") + + +def has_deep_gemm() -> bool: + """Whether the optional `deep_gemm` package is available.""" + return _has_module("deep_gemm") + + +def has_triton_kernels() -> bool: + """Whether the optional `triton_kernels` package is available.""" + return _has_module("triton_kernels") + + +def has_tilelang() -> bool: + """Whether the optional `tilelang` package is available.""" + return _has_module("tilelang") diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py index dcdc6ccb4c63..cde9aa6ff901 100644 --- a/vllm/utils/jsontree.py +++ b/vllm/utils/jsontree.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Helper functions to work with nested JSON structures.""" -from collections.abc import Iterable +from collections.abc import Callable, Iterable from functools import reduce -from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload +from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast, overload if TYPE_CHECKING: import torch @@ -14,23 +14,20 @@ _T = TypeVar("_T") _U = TypeVar("_U") -JSONTree = Union[ - dict[str, "JSONTree[_T]"], - list["JSONTree[_T]"], - tuple["JSONTree[_T]", ...], - _T, -] +JSONTree: TypeAlias = ( + dict[str, "JSONTree[_T]"] | list["JSONTree[_T]"] | tuple["JSONTree[_T]", ...] | _T +) """A nested JSON structure where the leaves need not be JSON-serializable.""" -_JSONTree = Union[ - dict[str, "JSONTree[_T]"], - list["JSONTree[_T]"], - tuple["JSONTree[_T]", ...], - dict[str, _T], - list[_T], - tuple[_T, ...], - _T, -] +_JSONTree: TypeAlias = ( + dict[str, "JSONTree[_T]"] + | list["JSONTree[_T]"] + | tuple["JSONTree[_T]", ...] + | dict[str, _T] + | list[_T] + | tuple[_T, ...] + | _T +) """ Same as `JSONTree` but with additional `Union` members to satisfy overloads. """ @@ -58,22 +55,22 @@ def json_map_leaves( @overload def json_map_leaves( func: Callable[[_T], _U], - value: Union[_T, dict[str, _T]], -) -> Union[_U, dict[str, _U]]: ... + value: _T | dict[str, _T], +) -> _U | dict[str, _U]: ... @overload def json_map_leaves( func: Callable[[_T], _U], - value: Union[_T, list[_T]], -) -> Union[_U, list[_U]]: ... + value: _T | list[_T], +) -> _U | list[_U]: ... @overload def json_map_leaves( func: Callable[[_T], _U], - value: Union[_T, tuple[_T, ...]], -) -> Union[_U, tuple[_U, ...]]: ... + value: _T | tuple[_T, ...], +) -> _U | tuple[_U, ...]: ... 
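# A sketch of the LazyLoader pattern shown above: bind a name now, import the
# real module only on first attribute access. Using numpy here is an
# illustrative assumption.
from vllm.utils.import_utils import LazyLoader

np = LazyLoader("np", globals(), "numpy")  # nothing is imported yet

arr = np.arange(4)  # first attribute access triggers the actual import
print(arr.sum())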
@overload @@ -85,8 +82,8 @@ def json_map_leaves( def json_map_leaves( func: Callable[[_T], _U], - value: Union["BatchedTensorInputs", _JSONTree[_T]], -) -> Union["BatchedTensorInputs", _JSONTree[_U]]: + value: "BatchedTensorInputs" | _JSONTree[_T], +) -> "BatchedTensorInputs" | _JSONTree[_U]: """Apply a function to each leaf in a nested JSON structure.""" if isinstance(value, dict): return { @@ -104,7 +101,7 @@ def json_map_leaves( @overload def json_reduce_leaves( func: Callable[[_T, _T], _T], - value: Union[_T, dict[str, _T]], + value: _T | dict[str, _T], /, ) -> _T: ... @@ -112,7 +109,7 @@ def json_reduce_leaves( @overload def json_reduce_leaves( func: Callable[[_T, _T], _T], - value: Union[_T, list[_T]], + value: _T | list[_T], /, ) -> _T: ... @@ -120,7 +117,7 @@ def json_reduce_leaves( @overload def json_reduce_leaves( func: Callable[[_T, _T], _T], - value: Union[_T, tuple[_T, ...]], + value: _T | tuple[_T, ...], /, ) -> _T: ... @@ -143,11 +140,11 @@ def json_reduce_leaves( def json_reduce_leaves( - func: Callable[..., Union[_T, _U]], + func: Callable[..., _T | _U], value: _JSONTree[_T], initial: _U = cast(_U, ...), # noqa: B008 /, -) -> Union[_T, _U]: +) -> _T | _U: """ Apply a function of two arguments cumulatively to each leaf in a nested JSON structure, from left to right, so as to reduce the diff --git a/vllm/utils/mem_constants.py b/vllm/utils/mem_constants.py new file mode 100644 index 000000000000..62b725fbb0f2 --- /dev/null +++ b/vllm/utils/mem_constants.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +MB_bytes = 1_000_000 +"""The number of bytes in one megabyte (MB).""" + +MiB_bytes = 1 << 20 +"""The number of bytes in one mebibyte (MiB).""" + +GB_bytes = 1_000_000_000 +"""The number of bytes in one gigabyte (GB).""" + +GiB_bytes = 1 << 30 +"""The number of bytes in one gibibyte (GiB).""" diff --git a/vllm/utils/mem_utils.py b/vllm/utils/mem_utils.py new file mode 100644 index 000000000000..c6a6757bed3b --- /dev/null +++ b/vllm/utils/mem_utils.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import gc +import time +from collections.abc import Generator +from dataclasses import dataclass, field +from functools import cache + +import psutil +import torch +import torch.types + +from .mem_constants import GiB_bytes + + +@cache +def get_max_shared_memory_bytes(gpu: int = 0) -> int: + """Returns the maximum shared memory per thread block in bytes.""" + from vllm import _custom_ops as ops + + max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu) + # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py + # will fail + assert max_shared_mem > 0, "max_shared_mem can not be zero" + return int(max_shared_mem) + + +def get_cpu_memory() -> int: + """Returns the total CPU memory of the node in bytes.""" + return psutil.virtual_memory().total + + +class DeviceMemoryProfiler: + def __init__(self, device: torch.types.Device | None = None): + self.device = device + + def current_memory_usage(self) -> float: + # Return the memory usage in bytes. 
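# A small sketch of the jsontree helpers whose (re-annotated) signatures
# appear above; the nested structure is an arbitrary example.
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves

tree = {"a": [1, 2], "b": {"c": (3, 4)}}

doubled = json_map_leaves(lambda x: x * 2, tree)
assert doubled == {"a": [2, 4], "b": {"c": (6, 8)}}

total = json_reduce_leaves(lambda x, y: x + y, tree, 0)
assert total == 10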
+ from vllm.platforms import current_platform + + gc.collect() + return current_platform.get_current_memory_usage(self.device) + + def __enter__(self): + self.initial_memory = self.current_memory_usage() + # This allows us to call methods of the context manager if needed + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.final_memory = self.current_memory_usage() + self.consumed_memory = self.final_memory - self.initial_memory + + # Force garbage collection + gc.collect() + + +@dataclass +class MemorySnapshot: + """Memory snapshot.""" + + torch_peak: int = 0 + free_memory: int = 0 + total_memory: int = 0 + cuda_memory: int = 0 + torch_memory: int = 0 + non_torch_memory: int = 0 + timestamp: float = 0.0 + auto_measure: bool = True + + def __post_init__(self): + if self.auto_measure: + self.measure() + + def measure(self): + from vllm.platforms import current_platform + + # we measure the torch peak memory usage via allocated_bytes, + # rather than `torch.cuda.memory_reserved()` . + # After `torch.cuda.reset_peak_memory_stats()`, + # `torch.cuda.memory_reserved()` will keep growing, and only shrink + # when we call `torch.cuda.empty_cache()` or OOM happens. + self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) + + self.free_memory, self.total_memory = torch.cuda.mem_get_info() + shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark + if ( + current_platform.is_cuda() + and current_platform.get_device_capability() in shared_sysmem_device_mem_sms + ): + # On UMA (Orin, Thor and Spark) platform, + # where both CPU and GPU rely on system memory, + # the cudaMemGetInfo function shows the amount of free system memory + # rather than what’s actually available. + # In the case, + # torch.cuda.mem_get_info() only reports "free" memory, + # which can be lower than what is actually + # available due to not including cache memory. + # There’s also a comprehensive reference page + # that explains how you can compute the proper value yourself. + # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device + self.free_memory = psutil.virtual_memory().available + + self.cuda_memory = self.total_memory - self.free_memory + + # torch.cuda.memory_reserved() is how many bytes + # PyTorch gets from cuda (by calling cudaMalloc, etc.) + # this is used to measure the non-torch memory usage + self.torch_memory = torch.cuda.memory_reserved() + + self.non_torch_memory = self.cuda_memory - self.torch_memory + self.timestamp = time.time() + + def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot": + return MemorySnapshot( + torch_peak=self.torch_peak - other.torch_peak, + free_memory=self.free_memory - other.free_memory, + total_memory=self.total_memory - other.total_memory, + cuda_memory=self.cuda_memory - other.cuda_memory, + torch_memory=self.torch_memory - other.torch_memory, + non_torch_memory=self.non_torch_memory - other.non_torch_memory, + timestamp=self.timestamp - other.timestamp, + auto_measure=False, + ) + + +@dataclass +class MemoryProfilingResult: + """Memory profiling result. 
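# A sketch of DeviceMemoryProfiler (defined above) as a context manager; the
# tensor allocation is an assumed example and needs a supported accelerator.
import torch

from vllm.utils.mem_utils import DeviceMemoryProfiler

with DeviceMemoryProfiler() as profiler:
    buf = torch.empty(1 << 20, dtype=torch.float32, device="cuda")

# consumed_memory is the usage measured after the block minus before it (bytes).
print(f"consumed ~{profiler.consumed_memory / (1 << 20):.1f} MiB")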
All numbers are in bytes.""" + + non_kv_cache_memory: int = 0 + torch_peak_increase: int = 0 + non_torch_increase: int = 0 + weights_memory: float = 0 + before_create: MemorySnapshot = field(default_factory=MemorySnapshot) + before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + profile_time: float = 0.0 + + def __repr__(self) -> str: + return ( + f"Memory profiling takes {self.profile_time:.2f} seconds. " + f"Total non KV cache memory: " + f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " + f"torch peak memory increase: " + f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " + f"non-torch forward increase memory: " + f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " + f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB." + ) + + +@contextlib.contextmanager +def memory_profiling( + baseline_snapshot: MemorySnapshot, weights_memory: int +) -> Generator[MemoryProfilingResult, None, None]: + """Memory profiling context manager. + baseline_snapshot: the memory snapshot before the current vLLM instance. + weights_memory: memory used by PyTorch when loading the model weights. + Note that, before loading the model weights, we also initialize the device + and distributed environment, which may consume some memory. This part is not + included in the weights_memory because PyTorch does not control it. + + The memory in one GPU can be classified into 3 categories: + 1. memory used by anything other than the current vLLM instance. + 2. memory used by torch in the current vLLM instance. + 3. memory used in the current vLLM instance, but not by torch. + + A quantitive example: + + Before creating the current vLLM instance: + category 1: 1 GiB + category 2: 0 GiB + category 3: 0 GiB + + After creating the current vLLM instance and loading the model, + (i.e. before profiling): + category 1: 1 GiB + category 2: 2 GiB (model weights take 2 GiB) + category 3: 0.5 GiB (memory used by NCCL) + + During profiling (peak): + category 1: 1 GiB + category 2: 4 GiB (peak activation tensors take 2 GiB) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + After profiling: + category 1: 1 GiB + category 2: 3 GiB (after garbage-collecting activation tensors) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + In this case, non-kv cache takes 5 GiB in total, including: + a. 2 GiB used by the model weights (category 2) + b. 2 GiB reserved for the peak activation tensors (category 2) + c. 1 GiB used by non-torch components (category 3) + + The memory used for loading weights (a.) is directly given from the argument `weights_memory`. + + The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). + + The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). 
+ """ # noqa + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + result = MemoryProfilingResult() + + result.before_create = baseline_snapshot + # the part of memory used for holding the model weights + result.weights_memory = weights_memory + + result.before_profile.measure() + + yield result + + gc.collect() + torch.cuda.empty_cache() + + result.after_profile.measure() + + diff_profile = result.after_profile - result.before_profile + diff_from_create = result.after_profile - result.before_create + result.torch_peak_increase = diff_profile.torch_peak + result.non_torch_increase = diff_from_create.non_torch_memory + result.profile_time = diff_profile.timestamp + + non_torch_memory = result.non_torch_increase + peak_activation_memory = result.torch_peak_increase + result.non_kv_cache_memory = ( + non_torch_memory + peak_activation_memory + result.weights_memory + ) # noqa diff --git a/vllm/utils/nccl.py b/vllm/utils/nccl.py new file mode 100644 index 000000000000..b1459fcbd246 --- /dev/null +++ b/vllm/utils/nccl.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import importlib +import os + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def find_nccl_library() -> str: + """Return NCCL/RCCL shared library name to load. + + Uses `VLLM_NCCL_SO_PATH` if set; otherwise chooses by torch backend. + """ + so_file = envs.VLLM_NCCL_SO_PATH + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file + ) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.debug_once("Found nccl from library %s", so_file) + return so_file + + +def find_nccl_include_paths() -> list[str] | None: + """Return possible include paths containing `nccl.h`. + + Considers `VLLM_NCCL_INCLUDE_PATH` and the `nvidia-nccl-cuXX` package. 
+ """ + paths: list[str] = [] + inc = envs.VLLM_NCCL_INCLUDE_PATH + if inc and os.path.isdir(inc): + paths.append(inc) + + try: + spec = importlib.util.find_spec("nvidia.nccl") + if spec and getattr(spec, "submodule_search_locations", None): + for loc in spec.submodule_search_locations: + inc_dir = os.path.join(loc, "include") + if os.path.exists(os.path.join(inc_dir, "nccl.h")): + paths.append(inc_dir) + except Exception as e: + logger.debug("Failed to find nccl include path from nvidia.nccl package: %s", e) + + seen: set[str] = set() + out: list[str] = [] + for p in paths: + if p and p not in seen: + out.append(p) + seen.add(p) + return out or None diff --git a/vllm/utils/network_utils.py b/vllm/utils/network_utils.py new file mode 100644 index 000000000000..0a68e48ba5e7 --- /dev/null +++ b/vllm/utils/network_utils.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import ipaddress +import os +import socket +import sys +import warnings +from collections.abc import ( + Iterator, + Sequence, +) +from typing import Any +from urllib.parse import urlparse +from uuid import uuid4 + +import psutil +import zmq +import zmq.asyncio + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]): + for sock in sockets: + if sock is not None: + sock.close(linger=0) + + +def get_ip() -> str: + host_ip = envs.VLLM_HOST_IP + if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ: + logger.warning( + "The environment variable HOST_IP is deprecated and ignored, as" + " it is often used by Docker and other software to" + " interact with the container's network stack. Please " + "use VLLM_HOST_IP instead to set the IP address for vLLM processes" + " to communicate with each other." + ) + if host_ip: + return host_ip + + # IP is not set, try to get it from the network interface + + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + # try ipv6 + try: + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s: + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + warnings.warn( + "Failed to get the IP address, using 0.0.0.0 by default." + "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", + stacklevel=2, + ) + return "0.0.0.0" + + +def test_loopback_bind(address, family): + try: + s = socket.socket(family, socket.SOCK_DGRAM) + s.bind((address, 0)) # Port 0 = auto assign + s.close() + return True + except OSError: + return False + + +def get_loopback_ip() -> str: + loopback_ip = envs.VLLM_LOOPBACK_IP + if loopback_ip: + return loopback_ip + + # VLLM_LOOPBACK_IP is not set, try to get it based on network interface + + if test_loopback_bind("127.0.0.1", socket.AF_INET): + return "127.0.0.1" + elif test_loopback_bind("::1", socket.AF_INET6): + return "::1" + else: + raise RuntimeError( + "Neither 127.0.0.1 nor ::1 are bound to a local interface. " + "Set the VLLM_LOOPBACK_IP environment variable explicitly." 
+ ) + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def split_host_port(host_port: str) -> tuple[str, int]: + # ipv6 + if host_port.startswith("["): + host, port = host_port.rsplit("]", 1) + host = host[1:] + port = port.split(":")[1] + return host, int(port) + else: + host, port = host_port.split(":") + return host, int(port) + + +def join_host_port(host: str, port: int) -> str: + if is_valid_ipv6_address(host): + return f"[{host}]:{port}" + else: + return f"{host}:{port}" + + +def get_distributed_init_method(ip: str, port: int) -> str: + return get_tcp_uri(ip, port) + + +def get_tcp_uri(ip: str, port: int) -> str: + if is_valid_ipv6_address(ip): + return f"tcp://[{ip}]:{port}" + else: + return f"tcp://{ip}:{port}" + + +def get_open_zmq_ipc_path() -> str: + base_rpc_path = envs.VLLM_RPC_BASE_PATH + return f"ipc://{base_rpc_path}/{uuid4()}" + + +def get_open_zmq_inproc_path() -> str: + return f"inproc://{uuid4()}" + + +def get_open_port() -> int: + """ + Get an open port for the vLLM process to listen on. + An edge case to handle, is when we run data parallel, + we need to avoid ports that are potentially used by + the data parallel master process. + Right now we reserve 10 ports for the data parallel master + process. Currently it uses 2 ports. + """ + if "VLLM_DP_MASTER_PORT" in os.environ: + dp_master_port = envs.VLLM_DP_MASTER_PORT + reserved_port_range = range(dp_master_port, dp_master_port + 10) + while True: + candidate_port = _get_open_port() + if candidate_port not in reserved_port_range: + return candidate_port + return _get_open_port() + + +def get_open_ports_list(count: int = 5) -> list[int]: + """Get a list of open ports.""" + ports = set[int]() + while len(ports) < count: + ports.add(get_open_port()) + return list(ports) + + +def _get_open_port() -> int: + port = envs.VLLM_PORT + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def find_process_using_port(port: int) -> psutil.Process | None: + # TODO: We can not check for running processes with network + # port on macOS. Therefore, we can not have a full graceful shutdown + # of vLLM. For now, let's not look for processes in this case. 
+ # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/ + if sys.platform.startswith("darwin"): + return None + + our_pid = os.getpid() + for conn in psutil.net_connections(): + if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid): + try: + return psutil.Process(conn.pid) + except psutil.NoSuchProcess: + return None + return None + + +def split_zmq_path(path: str) -> tuple[str, str, str]: + """Split a zmq path into its parts.""" + parsed = urlparse(path) + if not parsed.scheme: + raise ValueError(f"Invalid zmq path: {path}") + + scheme = parsed.scheme + host = parsed.hostname or "" + port = str(parsed.port or "") + + if scheme == "tcp" and not all((host, port)): + # The host and port fields are required for tcp + raise ValueError(f"Invalid zmq path: {path}") + + if scheme != "tcp" and port: + # port only makes sense with tcp + raise ValueError(f"Invalid zmq path: {path}") + + return scheme, host, port + + +def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str: + """Make a ZMQ path from its parts. + + Args: + scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc). + host: The host - can be an IPv4 address, IPv6 address, or hostname. + port: Optional port number, only used for TCP sockets. + + Returns: + A properly formatted ZMQ path string. + """ + if port is None: + return f"{scheme}://{host}" + if is_valid_ipv6_address(host): + return f"{scheme}://[{host}]:{port}" + return f"{scheme}://{host}:{port}" + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 +def make_zmq_socket( + ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined] + path: str, + socket_type: Any, + bind: bool | None = None, + identity: bytes | None = None, + linger: int | None = None, +) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined] + """Make a ZMQ socket with the proper bind/connect semantics.""" + + mem = psutil.virtual_memory() + socket = ctx.socket(socket_type) + + # Calculate buffer size based on system memory + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 + # For systems with substantial memory (>32GB total, >16GB available): + # - Set a large 0.5GB buffer to improve throughput + # For systems with less memory: + # - Use system default (-1) to avoid excessive memory consumption + buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1 + + if bind is None: + bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) + + if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.RCVHWM, 0) + socket.setsockopt(zmq.RCVBUF, buf_size) + + if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.SNDHWM, 0) + socket.setsockopt(zmq.SNDBUF, buf_size) + + if identity is not None: + socket.setsockopt(zmq.IDENTITY, identity) + + if linger is not None: + socket.setsockopt(zmq.LINGER, linger) + + if socket_type == zmq.XPUB: + socket.setsockopt(zmq.XPUB_VERBOSE, True) + + # Determine if the path is a TCP socket with an IPv6 address. + # Enable IPv6 on the zmq socket if so. 
+ scheme, host, _ = split_zmq_path(path) + if scheme == "tcp" and is_valid_ipv6_address(host): + socket.setsockopt(zmq.IPV6, 1) + + if bind: + socket.bind(path) + else: + socket.connect(path) + + return socket + + +@contextlib.contextmanager +def zmq_socket_ctx( + path: str, + socket_type: Any, + bind: bool | None = None, + linger: int = 0, + identity: bytes | None = None, +) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + ctx = zmq.Context() # type: ignore[attr-defined] + try: + yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity) + except KeyboardInterrupt: + logger.debug("Got Keyboard Interrupt.") + + finally: + ctx.destroy(linger=linger) diff --git a/vllm/utils/platform_utils.py b/vllm/utils/platform_utils.py new file mode 100644 index 000000000000..34ac820c6e9d --- /dev/null +++ b/vllm/utils/platform_utils.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing +from collections.abc import Sequence +from concurrent.futures.process import ProcessPoolExecutor +from functools import cache +from typing import Any + +import torch + + +def cuda_is_initialized() -> bool: + """Check if CUDA is initialized.""" + if not torch.cuda._is_compiled(): + return False + return torch.cuda.is_initialized() + + +def xpu_is_initialized() -> bool: + """Check if XPU is initialized.""" + if not torch.xpu._is_compiled(): + return False + return torch.xpu.is_initialized() + + +def cuda_get_device_properties( + device, names: Sequence[str], init_cuda=False +) -> tuple[Any, ...]: + """Get specified CUDA device property values without initializing CUDA in + the current process.""" + if init_cuda or cuda_is_initialized(): + props = torch.cuda.get_device_properties(device) + return tuple(getattr(props, name) for name in names) + + # Run in subprocess to avoid initializing CUDA as a side effect. + mp_ctx = multiprocessing.get_context("fork") + with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor: + return executor.submit(cuda_get_device_properties, device, names, True).result() + + +@cache +def is_pin_memory_available() -> bool: + from vllm.platforms import current_platform + + return current_platform.is_pin_memory_available() + + +@cache +def is_uva_available() -> bool: + """Check if Unified Virtual Addressing (UVA) is available.""" + # UVA requires pinned memory. + # TODO: Add more requirements for UVA if needed. + return is_pin_memory_available() diff --git a/vllm/utils/profiling.py b/vllm/utils/profiling.py new file mode 100644 index 000000000000..b66910693957 --- /dev/null +++ b/vllm/utils/profiling.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import contextlib +from collections.abc import Callable +from functools import wraps +from typing import Any + + +@contextlib.contextmanager +def cprofile_context(save_file: str | None = None): + """Run a cprofile + + Args: + save_file: path to save the profile result. "1" or + None will result in printing to stdout. + """ + import cProfile + + prof = cProfile.Profile() + prof.enable() + + try: + yield + finally: + prof.disable() + if save_file and save_file != "1": + prof.dump_stats(save_file) + else: + prof.print_stats(sort="cumtime") + + +def cprofile(save_file: str | None = None, enabled: bool = True): + """Decorator to profile a Python method using cProfile. 
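    A minimal usage sketch (the function name and output paths are illustrative):

        @cprofile(save_file="/tmp/step.prof")
        def run_step():
            ...

        # Or profile an arbitrary region directly:
        with cprofile_context(save_file="/tmp/region.prof"):
            run_step()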
+ + Args: + save_file: Path to save the profile result. + If "1", None, or "", results will be printed to stdout. + enabled: Set to false to turn this into a no-op + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(*args: Any, **kwargs: Any): + if not enabled: + # If profiling is disabled, just call the function directly. + return func(*args, **kwargs) + + with cprofile_context(save_file): + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/vllm/utils/serial_utils.py b/vllm/utils/serial_utils.py new file mode 100644 index 000000000000..b89fa6ce4db6 --- /dev/null +++ b/vllm/utils/serial_utils.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 +import sys +from dataclasses import dataclass +from typing import Literal + +import numpy as np +import torch +from typing_extensions import assert_never + +from vllm import PoolingRequestOutput + +sys_byteorder = sys.byteorder + + +EMBED_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + # I'm not sure if other platforms' CPUs support the fp8 data format. + # EMBED_DTYPE only uses the fp8 data representation, + # does not use fp8 computation, and only occurs on the CPU. + # Apologize for any possible break. + "fp8_e4m3": torch.float8_e4m3fn, + "fp8_e5m2": torch.float8_e5m2, +} + + +EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = { + "float32": torch.float32, + "float16": torch.float16, + # numpy does not support bfloat16 and fp8 + "bfloat16": torch.float16, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, +} + +EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = { + "float32": np.float32, + "float16": np.float16, + # numpy does not support bfloat16 and fp8 + "bfloat16": np.float16, + "fp8_e4m3": np.uint8, + "fp8_e5m2": np.uint8, +} + +ENDIANNESS = ["native", "big", "little"] + +EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"] +Endianness = Literal["native", "big", "little"] +EncodingFormat = Literal["float", "base64", "bytes"] + + +def tensor2binary( + tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness +) -> bytes: + assert isinstance(tensor, torch.Tensor) + assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE + assert endianness in ENDIANNESS + + torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype] + torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype] + + np_array = ( + tensor.to(torch_dtype).flatten().contiguous().view(torch_view_dtype).numpy() + ) + + if endianness != "native" and endianness != sys_byteorder: + np_array = np_array.byteswap() + + return np_array.tobytes() + + +def binary2tensor( + binary: bytes, + shape: tuple[int, ...], + embed_dtype: EmbedDType, + endianness: Endianness, +) -> torch.Tensor: + assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE + assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW + assert endianness in ENDIANNESS + + torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype] + np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype] + + np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape) + + if endianness != "native" and endianness != sys_byteorder: + np_array = np_array.byteswap() + + return torch.from_numpy(np_array).view(torch_dtype) + + +def encode_pooling_output( + output: PoolingRequestOutput, + encoding_format: EncodingFormat, + embed_dtype: EmbedDType, + endianness: Endianness, +) -> list[float] | str | bytes: + if encoding_format == "float": + return 
output.outputs.data.tolist() + elif encoding_format == "base64": + embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness) + return base64.b64encode(embedding_bytes).decode("utf-8") + elif encoding_format == "bytes": + return tensor2binary(output.outputs.data, embed_dtype, endianness) + assert_never(encoding_format) + + +@dataclass +class MetadataItem: + index: int + embed_dtype: EmbedDType + endianness: Endianness + start: int + end: int + shape: tuple[int, ...] + + +def encode_pooling_bytes( + pooling_outputs: list[PoolingRequestOutput], + embed_dtype: EmbedDType, + endianness: Endianness, +): + num_prompt_tokens = 0 + items: list[dict[str, MetadataItem]] = [] + body = [] + offset = 0 + for idx, output in enumerate(pooling_outputs): + binary = tensor2binary( + tensor=output.outputs.data, + embed_dtype=embed_dtype, + endianness=endianness, + ) + size = len(binary) + + item = { + "index": idx, + "embed_dtype": embed_dtype, + "endianness": endianness, + "start": offset, + "end": offset + size, + "shape": output.outputs.data.shape, + } + + body.append(binary) + items.append(item) + prompt_token_ids = output.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + offset += size + + usage = { + "prompt_tokens": num_prompt_tokens, + "total_tokens": num_prompt_tokens, + } + return body, items, usage + + +def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]: + items.sort(key=lambda x: x.index) + + tensor_list: list[torch.Tensor] = [] + for item in items: + binary = body[item.start : item.end] + tensor = binary2tensor(binary, item.shape, item.embed_dtype, item.endianness) + tensor_list.append(tensor) + return tensor_list diff --git a/vllm/utils/system_utils.py b/vllm/utils/system_utils.py new file mode 100644 index 000000000000..dd18adf55e1f --- /dev/null +++ b/vllm/utils/system_utils.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import contextlib +import os +import sys +from collections.abc import Callable, Iterator +from pathlib import Path +from typing import TextIO + +try: + import setproctitle +except ImportError: + setproctitle = None # type: ignore[assignment] + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + +CYAN = "\033[1;36m" +RESET = "\033[0;0m" + + +# Environment variable utilities + + +def update_environment_variables(envs_dict: dict[str, str]): + """Update multiple environment variables with logging.""" + for k, v in envs_dict.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s from '%s' to '%s'", + k, + os.environ[k], + v, + ) + os.environ[k] = v + + +@contextlib.contextmanager +def set_env_var(key: str, value: str) -> Iterator[None]: + """Temporarily set an environment variable.""" + old = os.environ.get(key) + os.environ[key] = value + try: + yield + finally: + if old is None: + os.environ.pop(key, None) + else: + os.environ[key] = old + + +# File path utilities + + +def unique_filepath(fn: Callable[[int], Path]) -> Path: + """Generate a unique file path by trying incrementing integers. + + Note: This function has a TOCTOU race condition. + Caller should use atomic operations (e.g., open with 'x' mode) + when creating the file to ensure thread safety. 
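    For example, assuming an illustrative directory and naming scheme:

        path = unique_filepath(lambda i: Path(f"/tmp/profile_{i}.json"))
        with open(path, "x") as f:  # 'x' mode keeps creation atomic
            f.write("{}")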
+ """ + i = 0 + while True: + p = fn(i) + if not p.exists(): + return p + i += 1 + + +# Process management utilities + + +def set_process_title( + name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX +) -> None: + """Set the current process title with optional suffix.""" + if setproctitle is None: + return + if suffix: + name = f"{name}_{suffix}" + setproctitle.setproctitle(f"{prefix}::{name}") + + +def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: + """Add colored prefix to file output for log decoration.""" + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: # type: ignore[attr-defined] + file_write(prefix) + idx = 0 + while (next_idx := s.find("\n", idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True # type: ignore[attr-defined] + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False # type: ignore[attr-defined] + + file.start_new_line = True # type: ignore[attr-defined] + file.write = write_with_prefix # type: ignore[method-assign] + + +def decorate_logs(process_name: str | None = None) -> None: + """Decorate stdout/stderr with process name and PID prefix.""" + from vllm.utils import get_mp_context + + if process_name is None: + process_name = get_mp_context().current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index e17676ccf7ef..526dfd38bac4 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints +from types import UnionType +from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints import torch @@ -12,16 +13,16 @@ class TensorShape: def __init__( self, - *dims: Union[int, str], - dynamic_dims: Optional[set[str]] = None, + *dims: int | str, + dynamic_dims: set[str] | None = None, ) -> None: super().__init__() self.dims = dims self.dynamic_dims = dynamic_dims if dynamic_dims else set() - def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]: - resolved = list[Union[int, str]]() + def resolve(self, **bindings: int) -> tuple[int | str, ...]: + resolved = list[int | str]() for dim in self.dims: if isinstance(dim, str) and dim in bindings: resolved.append(bindings[dim]) @@ -48,7 +49,7 @@ def __init__( self, *, validate: bool = True, - resolve_bindings: Optional[dict[str, int]] = None, + resolve_bindings: dict[str, int] | None = None, **kwargs: Any, ) -> None: super().__init__() @@ -71,7 +72,7 @@ def _match_shape_with_dynamic( self, actual: tuple[int, ...], reference: tuple[int, ...], - expected_shape: tuple[Union[int, str], ...], + expected_shape: tuple[int | str, ...], dynamic_dims: set[str], ) -> bool: if len(actual) != len(reference) or len(actual) > len(expected_shape): @@ -100,7 +101,7 @@ def _validate_field( self, value: object, field_name: str, - expected_shape: tuple[Union[int, str], ...], + expected_shape: tuple[int | str, ...], dynamic_dims: set[str], leading_idxs: tuple[int, ...] 
= (), ) -> tuple[int, ...]: @@ -154,7 +155,7 @@ def _validate_field( def _validate_tensor_shape_expected( self, actual_shape: tuple[int, ...], - expected_shape: tuple[Union[int, str], ...], + expected_shape: tuple[int | str, ...], field_name: str, shape_env: dict[str, int], dynamic_dims: set[str], @@ -209,7 +210,8 @@ def validate(self) -> None: actual_type = args[0] # Check arg was provided as Union - if get_origin(actual_type) is Union: + if get_origin(actual_type) in {Union, UnionType}: + # Union for Union[X, Y] and UnionType for X | Y args = get_args(actual_type) # Skip validation when Union contains None if type(None) in args: diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py new file mode 100644 index 000000000000..adcacb34cb7c --- /dev/null +++ b/vllm/utils/torch_utils.py @@ -0,0 +1,605 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import importlib.metadata +import threading +from collections.abc import Callable, Collection +from functools import lru_cache +from typing import TYPE_CHECKING, Any, TypeVar + +import numpy as np +import numpy.typing as npt +import torch +from packaging import version +from packaging.version import Version +from torch.library import Library + +import vllm.envs as envs + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.sequence import IntermediateTensors +else: + ModelConfig = object + IntermediateTensors = object + + +STR_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, + "int8": torch.int8, + "fp8_inc": torch.float8_e4m3fn, + "fp8_ds_mla": torch.uint8, +} + +TORCH_DTYPE_TO_NUMPY_DTYPE = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int32: np.int32, + torch.int64: np.int64, +} + + +T = TypeVar("T") + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +@contextlib.contextmanager +def set_default_torch_num_threads(num_threads: int): + """Sets the default number of threads for PyTorch to the given value.""" + old_num_threads = torch.get_num_threads() + torch.set_num_threads(num_threads) + yield + torch.set_num_threads(old_num_threads) + + +def get_dtype_size(dtype: torch.dtype) -> int: + """Get the size of the data type in bytes.""" + return torch.tensor([], dtype=dtype).element_size() + + +# bool = 0, int = 1, float = 2, complex = 3 +def _get_precision_level(dtype: torch.dtype) -> int: + # NOTE: Complex dtypes return `is_floating_point=False` + return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 + + +def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): + """ + Test whether it is lossless to cast a tensor from + `src_dtype` to `tgt_dtype`. 
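    A few illustrative cases, derived from the range/resolution checks below:

        is_lossless_cast(torch.float16, torch.float32)  # True
        is_lossless_cast(torch.float32, torch.float16)  # False
        is_lossless_cast(torch.int32, torch.int64)      # True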
+ """ + if src_dtype == tgt_dtype: + return True + + src_level = _get_precision_level(src_dtype) + tgt_level = _get_precision_level(tgt_dtype) + + if src_level < tgt_level: + return True + if src_level > tgt_level: + return False + + # Compare integral types + if not src_dtype.is_floating_point and not src_dtype.is_complex: + src_info = torch.iinfo(src_dtype) + tgt_info = torch.iinfo(tgt_dtype) + return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + + # Compare floating-point types + src_info = torch.finfo(src_dtype) + tgt_info = torch.finfo(tgt_dtype) + return ( + src_info.min >= tgt_info.min + and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution + ) + + +def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): + """ + Get the common `dtype` where all of the other `dtypes` can be + cast to it without losing any information. + """ + return max( + dtypes, + key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), + ) + + +def _generate_random_fp8( + tensor: torch.Tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. + # | E4M3 | E5M2 + # -----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm import _custom_ops as ops + + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + ops.convert_fp8(tensor, tensor_tmp) + del tensor_tmp + + +def get_kv_cache_torch_dtype( + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, +) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def kv_cache_dtype_str_to_dtype( + kv_cache_dtype: str, model_config: ModelConfig +) -> torch.dtype: + if kv_cache_dtype == "auto": + # Model config may not be specified for unit tests, default to float16 + return model_config.dtype if model_config else torch.half + return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, + seed: int | None = None, + device: str | None = "cuda", + cache_layout: str | None = "NHD", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + from vllm.platforms import current_platform + + current_platform.seed_everything(seed) + + dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + assert cache_layout in ("NHD", "HND") + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) + + kv_cache_allocation_shape = 
tuple(generic_kv_cache_shape[i] for i in stride_order) + scale = head_size**-0.5 + + key_caches: list[torch.Tensor] = [] + value_caches: list[torch.Tensor] = [] + + for _ in range(num_layers): + key_value_cache = torch.empty( + size=kv_cache_allocation_shape, dtype=dtype, device=device + ).permute(*stride_order) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, + seed: int | None = None, + device: str | None = "cuda", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + from vllm.platforms import current_platform + + current_platform.seed_everything(seed) + + dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError(f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + + +def async_tensor_h2d( + data: list, + dtype: torch.dtype, + target_device: str | torch.device, + pin_memory: bool, +) -> torch.Tensor: + """Asynchronously create a tensor and copy it from host to device.""" + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) + + +def make_ndarray_with_pad( + x: list[list[T]], + pad: T, + dtype: npt.DTypeLike, + *, + max_len: int | None = None, +) -> npt.NDArray: + """ + Make a padded array from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. 
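    For example, with a small illustrative input:

        make_ndarray_with_pad([[1, 2, 3], [4]], pad=0, dtype=np.int32)
        # -> array([[1, 2, 3],
        #           [4, 0, 0]], dtype=int32)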
+ """ + if max_len is None: + # Unlike for most functions, map is faster than a genexpr over `len` + max_len = max(map(len, x), default=0) + + padded_x = np.full((len(x), max_len), pad, dtype=dtype) + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, : len(blocktb)] = blocktb + + return padded_x + + +def make_tensor_with_pad( + x: list[list[T]], + pad: T, + dtype: torch.dtype, + *, + max_len: int | None = None, + device: str | torch.device | None = None, + pin_memory: bool = False, +) -> torch.Tensor: + """ + Make a padded tensor from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + tensor = tensor.pin_memory() + + return tensor + + +prev_set_stream = torch.cuda.set_stream + +_current_stream_tls = threading.local() + + +def _patched_set_stream(stream: torch.cuda.Stream) -> None: + _current_stream_tls.value = stream + prev_set_stream(stream) + + +torch.cuda.set_stream = _patched_set_stream + + +class _StreamPlaceholder: + def __init__(self): + self.synchronize = lambda: None + + +def current_stream() -> torch.cuda.Stream: + """ + replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. + it turns out that `torch.cuda.current_stream()` is quite expensive, + as it will construct a new stream object at each call. + here we patch `torch.cuda.set_stream` to keep track of the current stream + directly, so that we can avoid calling `torch.cuda.current_stream()`. + + the underlying hypothesis is that we do not call `torch._C._cuda_setStream` + from C/C++ code. + """ + from vllm.platforms import current_platform + + if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: + # when this function is called before any stream is set, + # we return the default stream. + # On ROCm using the default 0 stream in combination with RCCL + # is hurting performance. Therefore creating a dedicated stream + # per process + if current_platform.is_rocm(): + # torch.cuda.set_stream here is the alias of _pathed_set_stream + torch.cuda.set_stream(torch.cuda.Stream()) + elif current_platform.is_cpu(): + _current_stream_tls.value = _StreamPlaceholder() + else: + current_stream = current_platform.current_stream + if current_stream is not None: + _current_stream_tls.value = current_stream() + else: + raise ValueError( + "Fail to set current stream, current platform " + "may not support current_stream with torch API" + ) + return _current_stream_tls.value + + +@lru_cache(maxsize=8) +def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. 
+ + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + from vllm.platforms import current_platform + + if not torch.cuda._is_compiled(): + return 0 + if current_platform.is_rocm(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = ( + torch.cuda._device_count_amdsmi() + if (hasattr(torch.cuda, "_device_count_amdsmi")) + else -1 + ) + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count + return r + + +def cuda_device_count_stateless() -> int: + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. + return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. + The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + return torch.ops._C.weak_ref_tensor(tensor) + else: + return tensor + + +def weak_ref_tensors( + tensors: torch.Tensor + | list[torch.Tensor] + | tuple[torch.Tensor] + | IntermediateTensors, +) -> torch.Tensor | list[Any] | tuple[Any] | Any: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + + # For IntermediateTensors used in pipeline parallelism + from vllm.sequence import IntermediateTensors + + if isinstance(tensors, IntermediateTensors): + ret = IntermediateTensors( + {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} + ) + return ret + raise ValueError("Invalid type for tensors") + + +def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: + """ + Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA). + """ + assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" + return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) + + +# Helper function used in testing. +def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: + torch_version = version.parse(torch_version) + return torch_version >= version.parse(target) + + +def is_torch_equal_or_newer(target: str) -> bool: + """Check if the installed torch version is >= the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal_or_newer(str(torch.__version__), target) + except Exception: + # Fallback to PKG-INFO to load the package info, needed by the doc gen. 
+ return Version(importlib.metadata.version("torch")) >= Version(target) + + +def _is_torch_equal(target: str) -> bool: + assert target.count(".") == 2 + torch_version = str(torch.__version__) + torch_version = version.parse(torch_version) + # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu" + # or "2.6.0+cu128" but never "2.6.0.1" + return ( + torch_version >= version.parse(target) + and version.parse(target + ".1") > torch_version + ) + + +def is_torch_equal(target: str) -> bool: + """Check if the installed torch version is == the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal(target) + except Exception: + return Version(importlib.metadata.version("torch")) == Version(target) + + +# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. +def supports_dynamo() -> bool: + return is_torch_equal_or_newer("2.4.0") + + +# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform +def supports_xccl() -> bool: + return ( + is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() + ) + + +# Some backends use pytorch version < 2.4.0 which doesn't +# support `torch.library.custom_op`. +def supports_custom_op() -> bool: + return hasattr(torch.library, "custom_op") + + +# create a library to hold the custom op +vllm_lib = Library("vllm", "FRAGMENT") # noqa + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str] | None = None, + fake_impl: Callable | None = None, + target_lib: Library | None = None, + dispatch_key: str | None = None, + tags: tuple[torch.Tag, ...] = (), +): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. This function + directly registers a custom op and dispatches it to the CUDA backend. + See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. + """ + if not supports_custom_op(): + from vllm.platforms import current_platform + + assert not current_platform.is_cuda_alike(), ( + "cuda platform needs torch>=2.4 to support custom op, " + "chances are you are using an old version of pytorch " + "or a custom build of pytorch. It is recommended to " + "use vLLM in a fresh new environment and let it install " + "the required dependencies." 
+ ) + return + + if mutates_args is None: + mutates_args = [] + + if dispatch_key is None: + from vllm.platforms import current_platform + + dispatch_key = current_platform.dispatch_key + + import torch.library + + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + else: + # for pytorch 2.4 + import torch._custom_op.impl + + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) + my_lib = target_lib or vllm_lib + my_lib.define(op_name + schema_str, tags=tags) + my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 6e27e93c9115..0d3e1729ff20 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -110,7 +110,7 @@ class TorchSDPAMetadata(AttentionMetadata): """Metadata for PagedAttention.""" # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. - decode_seq_lens_tensor: Optional[torch.Tensor] + decode_seq_lens_tensor: torch.Tensor | None # Maximum sequence length in the batch. 0 if it is prefill-only batch. decode_max_seq_len: int # (batch_size, max_blocks_per_seq). @@ -119,39 +119,39 @@ class TorchSDPAMetadata(AttentionMetadata): # in the kv cache. Each block can contain up to block_size tokens. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. - decode_block_tables: Optional[torch.Tensor] + decode_block_tables: torch.Tensor | None """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. chunked_prefill: bool - seq_lens: Optional[list[int]] = None # For non-chunked prefill + seq_lens: list[int] | None = None # For non-chunked prefill # For chunked prefill only - max_query_len: Optional[int] = None - prefill_max_seq_len: Optional[int] = None - prefill_query_start_loc: Optional[torch.Tensor] = None - prefill_seq_start_loc: Optional[torch.Tensor] = None - prefill_block_tables: Optional[torch.Tensor] = None + max_query_len: int | None = None + prefill_max_seq_len: int | None = None + prefill_query_start_loc: torch.Tensor | None = None + prefill_seq_start_loc: torch.Tensor | None = None + prefill_block_tables: torch.Tensor | None = None # For V1 logits index only - query_start_loc: Optional[torch.Tensor] = None + query_start_loc: torch.Tensor | None = None # Begin encoder attn & enc/dec cross-attn fields... # Encoder sequence lengths representation - encoder_seq_lens: Optional[list[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None + encoder_seq_lens: list[int] | None = None + encoder_seq_lens_tensor: torch.Tensor | None = None # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None + max_encoder_seq_len: int | None = None # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None + num_encoder_tokens: int | None = None # Cross-attention memory-mapping data structures: slot mapping # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None + cross_slot_mapping: torch.Tensor | None = None + cross_block_tables: torch.Tensor | None = None def __post_init__(self): # Set during the execution of the first attention op. @@ -159,9 +159,9 @@ def __post_init__(self): # when alibi slopes is used. 
It is because of the limitation # from xformer API. # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[list[torch.Tensor]] = None - self.encoder_attn_bias: Optional[list[torch.Tensor]] = None - self.cross_attn_bias: Optional[list[torch.Tensor]] = None + self.attn_bias: list[torch.Tensor] | None = None + self.encoder_attn_bias: list[torch.Tensor] | None = None + self.cross_attn_bias: list[torch.Tensor] | None = None @property def is_all_encoder_attn_metadata_set(self): @@ -237,7 +237,7 @@ def get_seq_lens( def get_attn_bias( self, attn_type: str, - ) -> Optional[list[torch.Tensor]]: + ) -> list[torch.Tensor] | None: """ Extract appropriate attention bias from attention metadata according to attention type. @@ -412,7 +412,7 @@ def build( num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, # to ensure inference when chunked_prefill is disabled - seq_lens=seq_lens_cpu.tolist(), + seq_lens=seq_lens_cpu.tolist()[num_decodes:], # prefill decode_seq_lens_tensor=seq_lens_cpu[:num_decodes], # decode decode_max_seq_len=max_decode_seq_len, # decode decode_block_tables=block_table_tensor[:num_decodes], # decode @@ -439,12 +439,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, + kv_sharing_target_layer_name: str | None = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") @@ -484,9 +484,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TorchSDPAMetadata, # type: ignore - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -617,7 +617,6 @@ def forward( prefill_meta.prefill_block_tables, self.alibi_slopes, ) - if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( "Encoder-only models should not have decode metadata." 
@@ -686,7 +685,12 @@ def _run_sdpa_forward( causal_attn = attn_type == AttentionType.DECODER seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) - start_q, start_kv = 0, 0 + # Incoming Q and KV contain decoded tokens as well, hence start at an offset + # equal to num_decode_tokens since decode requests appear first + start_q, start_kv = ( + attn_metadata.num_decode_tokens, + attn_metadata.num_decode_tokens, + ) for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks): end_q = start_q + seq_len_q end_kv = start_kv + seq_len_kv @@ -737,7 +741,7 @@ def _make_alibi_bias( def _make_sliding_window_bias( seq_lens: list[int], - window_size: Optional[int], + window_size: int | None, dtype: torch.dtype, ) -> list[torch.Tensor]: attn_biases: list[torch.Tensor] = [] @@ -824,7 +828,7 @@ def forward_decode( kv_cache_dtype: str, num_kv_heads: int, scale: float, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, k_scale: torch.Tensor, v_scale: torch.Tensor, *args, @@ -907,7 +911,7 @@ def forward_decode( kv_cache_dtype: str, num_kv_heads: int, scale: float, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, k_scale: torch.Tensor, v_scale: torch.Tensor, *args, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1f6b7e41b37e..720fbd2c15c5 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -3,7 +3,6 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Optional import numpy as np import torch @@ -14,9 +13,11 @@ AttentionImpl, AttentionMetadata, AttentionType, + MultipleOf, is_quantized_kv_cache, ) from vllm.attention.layer import Attention +from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import ( flash_attn_supports_fp8, @@ -30,9 +31,12 @@ get_scheduler_metadata, reshape_and_cache_flash, ) - from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -47,7 +51,6 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True - supports_quant_query_input: bool = True @classmethod def get_supported_dtypes(cls) -> list[torch.dtype]: @@ -57,6 +60,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -139,13 +146,17 @@ class FlashAttentionMetadata: # For cascade attention. 
use_cascade: bool common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] + cu_prefix_query_lens: torch.Tensor | None + prefix_kv_lens: torch.Tensor | None + suffix_kv_lens: torch.Tensor | None + + # For GQA DCP + max_dcp_context_kv_len: int | None = None + dcp_context_kv_lens: torch.Tensor | None = None # Optional aot scheduling - scheduler_metadata: Optional[torch.Tensor] = None - prefix_scheduler_metadata: Optional[torch.Tensor] = None + scheduler_metadata: torch.Tensor | None = None + prefix_scheduler_metadata: torch.Tensor | None = None max_num_splits: int = 0 causal: bool = True @@ -153,9 +164,9 @@ class FlashAttentionMetadata: def _get_sliding_window_configs( vllm_config: VllmConfig, -) -> set[Optional[tuple[int, int]]]: +) -> set[tuple[int, int] | None]: """Get the set of all sliding window configs used in the model.""" - sliding_window_configs: set[Optional[tuple[int, int]]] = set() + sliding_window_configs: set[tuple[int, int] | None] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) for layer in layers.values(): assert isinstance(layer.impl, FlashAttentionImpl) @@ -212,10 +223,20 @@ def __init__( self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = get_flash_attn_version() == 3 + try: + from vllm.distributed.parallel_state import get_dcp_group + + self.dcp_world_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.dcp_world_size = 1 + self.dcp_rank = 0 + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) - self.max_cudagraph_size = self.compilation_config.max_capture_size + self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size if self.use_full_cuda_graph and self.aot_schedule: if self.max_cudagraph_size > 992: @@ -237,7 +258,7 @@ def __init__( # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. - self.aot_sliding_window: Optional[tuple[int, int]] = None + self.aot_sliding_window: tuple[int, int] | None = None def build( self, @@ -287,6 +308,9 @@ def build( # we only set num_splits when using cuda graphs. 
max_num_splits = self.max_num_splits + if vllm_is_batch_invariant(): + max_num_splits = 1 + def schedule( batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal ): @@ -302,7 +326,7 @@ def schedule( batch_size=batch_size, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, - num_heads_q=self.num_heads_q, + num_heads_q=self.num_heads_q * self.dcp_world_size, num_heads_kv=self.num_heads_kv, headdim=self.headdim, cache_seqlens=seqlens, @@ -316,8 +340,35 @@ def schedule( return None use_cascade = common_prefix_len > 0 + max_dcp_context_kv_len = 0 + dcp_context_kv_lens = None + + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + if self.dcp_world_size > 1: + query_kv_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu + dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + ( + self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size + ) + dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) + max_dcp_context_kv_len = dcp_context_kv_lens.max().item() - if use_cascade: + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=dcp_context_kv_lens, + max_seq_len=max_dcp_context_kv_len, + causal=False, + ) + elif use_cascade: cu_prefix_query_lens = torch.tensor( [0, num_actual_tokens], dtype=torch.int32, device=self.device ) @@ -344,10 +395,6 @@ def schedule( causal=True, ) else: - cu_prefix_query_lens = None - prefix_kv_lens = None - suffix_kv_lens = None - prefix_scheduler_metadata = None scheduler_metadata = schedule( batch_size=num_reqs, cu_query_lens=query_start_loc, @@ -375,6 +422,8 @@ def schedule( seq_lens=seq_lens, block_table=block_table_tensor, slot_mapping=slot_mapping, + max_dcp_context_kv_len=max_dcp_context_kv_len, + dcp_context_kv_lens=dcp_context_kv_lens, use_cascade=use_cascade, common_prefix_len=common_prefix_len, scheduler_metadata=scheduler_metadata, @@ -392,19 +441,21 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class FlashAttentionImpl(AttentionImpl): + can_return_lse_for_decode: bool = True + def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: str | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -432,6 +483,9 @@ def __init__( self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() + # Cache the batch invariant result for use in forward passes + self.batch_invariant_enabled = vllm_is_batch_invariant() + if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device." 
@@ -447,6 +501,9 @@ def __init__( "heads in the layer" ) + def supports_quant_query_input(self) -> bool: + return True + def forward( self, layer: torch.nn.Module, @@ -455,9 +512,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -483,7 +540,7 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) attn_type = self.attn_type @@ -558,30 +615,45 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) - flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=attn_metadata.causal, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - scheduler_metadata=scheduler_metadata, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - num_splits=attn_metadata.max_num_splits, - s_aux=self.sinks, - ) - return output + if self.dcp_world_size > 1: + self._forward_with_dcp( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + output[:num_actual_tokens], + attn_metadata, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return output + else: + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + scheduler_metadata=scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + num_splits=attn_metadata.max_num_splits, + s_aux=self.sinks, + ) + return output # Cascade attention (rare case). 
cascade_attention( @@ -611,6 +683,86 @@ def forward( ) return output + def _forward_with_dcp( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + ) -> torch.Tensor: + cu_seqlens_q = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + block_table = attn_metadata.block_table + + query = query.contiguous() + query_across_dcp = get_dcp_group().all_gather(query, dim=1) + context_attn_out, context_lse = flash_attn_varlen_func( + q=query_across_dcp, + k=key_cache, + v=value_cache, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=attn_metadata.dcp_context_kv_lens, + max_seqlen_k=attn_metadata.max_dcp_context_kv_len, + softmax_scale=self.scale, + causal=False, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + scheduler_metadata=attn_metadata.scheduler_metadata, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ] + context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs( + context_attn_out, + context_lse.transpose(0, 1), + get_dcp_group(), + return_lse=True, + ) + context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() + + query_attn_out, query_lse = flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=None, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + cu_seqlens_k=cu_seqlens_q, + max_seqlen_k=max_seqlen_q, + softmax_scale=self.scale, + causal=attn_metadata.causal, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + return_softmax_lse=True, + fa_version=self.vllm_flash_attn_version, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + assert context_attn_out_cor.shape == query_attn_out.shape + assert context_lse_cor.shape == query_lse.shape + merge_attn_states( + output, + context_attn_out_cor, + context_lse_cor, + query_attn_out, + query_lse, + ) + def _forward_encoder_attention( self, query: torch.Tensor, @@ -666,6 +818,7 @@ def _forward_encoder_attention( q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + num_splits=1 if self.batch_invariant_enabled else 0, ) return output @@ -680,6 +833,7 @@ def use_cascade_attention( use_sliding_window: bool, use_local_attention: bool, num_sms: int, + dcp_world_size: int, ) -> bool: """Decide whether to use cascade attention. @@ -701,6 +855,9 @@ def use_cascade_attention( num_reqs = len(query_lens) if num_reqs < 8: return False + # disable cascade attention for DCP + if dcp_world_size > 1: + return False # Heuristics to decide whether using cascade attention is beneficial. # 1. 
When FlashDecoding is not used for normal attention, cascade attention @@ -757,18 +914,18 @@ def cascade_attention( suffix_kv_lens: torch.Tensor, max_kv_len: int, softmax_scale: float, - alibi_slopes: Optional[torch.Tensor], + alibi_slopes: torch.Tensor | None, sliding_window: tuple[int, int], logits_soft_cap: float, block_table: torch.Tensor, common_prefix_len: int, fa_version: int, - prefix_scheduler_metadata: Optional[torch.Tensor] = None, - suffix_scheduler_metadata: Optional[torch.Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, - s_aux: Optional[torch.Tensor] = None, + prefix_scheduler_metadata: torch.Tensor | None = None, + suffix_scheduler_metadata: torch.Tensor | None = None, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, + s_aux: torch.Tensor | None = None, ) -> torch.Tensor: assert alibi_slopes is None, "Cascade attention does not support ALiBi." # TODO: Support sliding window. @@ -806,6 +963,7 @@ def cascade_attention( # s_aux is incorporated into prefix_lse inside the GPU kernel, # enabling its effect during the final attention merge. s_aux=s_aux, + num_splits=1 if vllm_is_batch_invariant() else 0, ) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) @@ -830,6 +988,7 @@ def cascade_attention( q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + num_splits=1 if vllm_is_batch_invariant() else 0, ) # Merge prefix and suffix outputs, and store the result in output. diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 38cf0ca56733..029293d2f6dd 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -2,10 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashInfer.""" -from __future__ import annotations - from dataclasses import dataclass -from typing import ClassVar, Union +from typing import ClassVar import numpy as np import torch @@ -18,14 +16,17 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor -from vllm import _custom_ops as ops from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, AttentionType, + MultipleOf, ) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, @@ -33,13 +34,13 @@ ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, is_pin_memory_available +from vllm.utils import cdiv from vllm.utils.flashinfer import ( can_use_trtllm_attention, flashinfer_disable_q_quantization, - supports_trtllm_attention, use_trtllm_attention, ) +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -52,6 +53,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 FP8_DTYPE = 
current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -166,6 +168,13 @@ def get_supported_head_sizes(cls) -> list[int]: # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 return [64, 128, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + # Note: Not sure for all platforms, + # but on Blackwell, only support a page size of + # 16, 32, 64 + return [16, 32, 64] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -183,15 +192,15 @@ def get_name() -> str: return "FLASHINFER" @staticmethod - def get_impl_cls() -> type[FlashInferImpl]: + def get_impl_cls() -> type["FlashInferImpl"]: return FlashInferImpl @staticmethod - def get_metadata_cls() -> type[FlashInferMetadata]: + def get_metadata_cls() -> type["FlashInferMetadata"]: return FlashInferMetadata @staticmethod - def get_builder_cls() -> type[FlashInferMetadataBuilder]: + def get_builder_cls() -> type["FlashInferMetadataBuilder"]: return FlashInferMetadataBuilder @staticmethod @@ -283,12 +292,27 @@ def __init__( self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) + if vllm_is_batch_invariant(): + self.decode_fixed_split_size = 2048 + self.prefill_fixed_split_size = 4096 + self.disable_split_kv = True + else: + self.decode_fixed_split_size = -1 + self.prefill_fixed_split_size = -1 + self.disable_split_kv = False + self.compilation_config = vllm_config.compilation_config max_num_pages_per_req = cdiv( self.model_config.max_model_len, self.kv_cache_spec.block_size ) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req + speculative_config = vllm_config.speculative_config + num_spec_tokens = ( + speculative_config.num_speculative_tokens + if speculative_config is not None + else 0 + ) self.enable_cuda_graph = ( self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL ) @@ -299,7 +323,8 @@ def __init__( int, BatchDecodeWithPagedKVCacheWrapper ] = {} self._decode_cudagraph_max_bs = min( - max_num_reqs, self.compilation_config.max_capture_size + (1 + num_spec_tokens) * max_num_reqs, + self.compilation_config.max_cudagraph_capture_size, ) self.num_qo_heads = self.model_config.get_num_attention_heads( @@ -323,15 +348,13 @@ def __init__( # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. 
Otherwise, try to # use fp8 q if kv cache is fp8, and will fall back to model dtype # if TRTLLM attention kernel is not used when building attn metadata - if supports_trtllm_attention() and not flashinfer_disable_q_quantization(): + can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) + if can_use_trtllm and not flashinfer_disable_q_quantization(): self.q_data_type = self.kv_cache_dtype else: self.q_data_type = self.model_config.dtype - supports_spec_as_decode = can_use_trtllm_attention( - self.num_qo_heads, self.num_kv_heads - ) - self._init_reorder_batch_threshold(1, supports_spec_as_decode) + self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm) self._cascade_wrapper = None # Wrapper for cascade attention @@ -344,7 +367,7 @@ def __init__( self.window_left = self.global_hyperparameters.window_left self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap self.has_sinks = self.global_hyperparameters.has_sinks - if self.has_sinks and not supports_trtllm_attention(): + if self.has_sinks and not can_use_trtllm: raise NotImplementedError( "FlashInfer backend currently does not support attention " "sinks, please use trtllm on blackwell or flash attention on " @@ -381,8 +404,11 @@ def __init__( def _get_workspace_buffer(self): if self._workspace_buffer is None: + buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE + if vllm_is_batch_invariant(): + buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT self._workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device + buffer_size, dtype=torch.uint8, device=self.device ) return self._workspace_buffer @@ -548,16 +574,30 @@ def build( has_sinks=self.has_sinks, has_spec=uses_spec_reorder, ) - if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm): - raise NotImplementedError( - "FlashInfer backend currently does not support attention " - "sinks, please use trtllm on blackwell or flash attention on " - "earlier GPUs." - ) - # If TRTLLM attention is not used, the q quantization is not supported. - # Fall back to use model dtype. if not (prefill_use_trtllm and decode_use_trtllm): + if self.has_sinks: + raise NotImplementedError( + "FlashInfer backend currently does not support attention " + "sinks, please use trtllm on blackwell or flash attention " + "on earlier GPUs." + ) + + if not self.global_hyperparameters.has_same_window_lefts: + raise ValueError( + "Window left is not the same for all layers. " + "One potential fix is to set disable_sliding_window=True" + ) + + assert self.global_hyperparameters.has_same_all_params, ( + "FlashInfer backend currently only supports models in which " + "all layers share the same values for the following " + "hyperparameters: `window_left`, `logits_soft_cap`, " + "`sm_scale`." + ) + + # The q quantization is not supported for non-trtllm attention, + # fall back to model dtype. 
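Note (illustrative): the fixed split sizes and `disable_split_kv` under `vllm_is_batch_invariant()` exist because split-KV attention changes the floating-point reduction order depending on how many splits the planner picks, and float addition is not associative. A self-contained illustration in plain Python (nothing vLLM-specific):

# Summing the same values in a different order changes the low bits of the
# result, so a split count that depends on batch shape breaks run-to-run
# and batch-composition invariance. Pinning the split sizes (or forcing a
# single split) keeps the reduction order fixed.
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)  # 0.6000000000000001
print(a + (b + c))  # 0.6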
self.q_data_type = self.model_config.dtype attn_metadata = FlashInferMetadata( @@ -645,6 +685,8 @@ def build( logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, ) else: attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( @@ -660,7 +702,7 @@ def build( use_cudagraph = ( self.enable_cuda_graph and pure_decode - and num_decodes <= self._decode_cudagraph_max_bs + and num_decode_tokens <= self._decode_cudagraph_max_bs ) if use_cudagraph: num_input_tokens = self.vllm_config.pad_for_cudagraph( @@ -706,6 +748,8 @@ def build( logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.decode_fixed_split_size, + disable_split_kv=self.disable_split_kv, ) return attn_metadata @@ -772,9 +816,7 @@ def __init__( ) self.sinks = sinks - self.support_trtllm_attn = ( - supports_trtllm_attention() and num_heads % num_kv_heads == 0 - ) + self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads) self.bmm1_scale: float | None = None self.bmm2_scale: float | None = None self.o_sf_scale: float | None = None @@ -786,6 +828,17 @@ def fused_output_quant_supported(self, quant_key: QuantKey): and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) ) + def supports_quant_query_input(self) -> bool: + if flashinfer_disable_q_quantization(): + return False + + return self.support_trtllm_attn + + # FlashInfer requires attention sinks to be float32 + def process_weights_after_loading(self, act_dtype: torch.dtype): + if self.sinks is not None and self.sinks.dtype != torch.float32: + self.sinks = self.sinks.to(torch.float32) + def forward( self, layer: torch.nn.Module, @@ -815,7 +868,13 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) + + # Ensure query dtype matches the expected dtype from attention metadata + assert attn_metadata.q_data_type == query.dtype, ( + f"Query dtype mismatch: expected {attn_metadata.q_data_type}, " + f"got {query.dtype}" + ) if self.bmm1_scale is None: self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale @@ -857,15 +916,6 @@ def forward( elif output.dtype == FP4_DTYPE: self.o_sf_scale = layer._o_scale_float - # Insert FP8 quant for query - if attn_metadata.q_data_type == FP8_DTYPE: - num_tokens, num_heads, head_size = query.shape - query, _ = ops.scaled_fp8_quant( - query.reshape((num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale, - ) - query = query.reshape((num_tokens, num_heads, head_size)) - # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. 
Thus, we need to be careful about any CPU overhead @@ -1092,13 +1142,15 @@ def fast_plan_decode( pos_encoding_mode: str = "NONE", window_left: int = -1, logits_soft_cap: float | None = None, - q_data_type: Union[str, torch.dtype] | None = "float16", - kv_data_type: Union[str, torch.dtype] | None = None, - data_type: Union[str, torch.dtype] | None = None, + q_data_type: str | torch.dtype | None = "float16", + kv_data_type: str | torch.dtype | None = None, + data_type: str | torch.dtype | None = None, sm_scale: float | None = None, rope_scale: float | None = None, rope_theta: float | None = None, non_blocking: bool = True, + fixed_split_size: int = -1, + disable_split_kv: bool = False, ) -> None: """ A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for @@ -1135,6 +1187,10 @@ def fast_plan_decode( rope_scale, rope_theta, non_blocking, + None, # block_tables + None, # seq_lens + fixed_split_size, + disable_split_kv, ) self.vllm_first_call = False return @@ -1182,7 +1238,7 @@ def fast_plan_decode( qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") try: - # Make sure we pass exactly 15 arguments for tensor core version + # Make sure we pass exactly 18 arguments for tensor core version self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -1199,6 +1255,9 @@ def fast_plan_decode( head_dim, head_dim, False, # causal + window_left, + fixed_split_size, + disable_split_kv, ) except Exception as e: raise RuntimeError(f"Error in tensor core plan: {e}") from e diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 7775445ae773..ffea14ec63f8 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -3,7 +3,6 @@ """Attention layer with FlexAttention.""" from dataclasses import dataclass -from typing import Optional, Union import torch import torch._dynamo.decorators @@ -27,9 +26,10 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( - vllm_kernel_override_batch_invariant, + vllm_is_batch_invariant, ) -from vllm.utils import cdiv, is_torch_equal_or_newer +from vllm.utils import cdiv +from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, @@ -282,9 +282,9 @@ class FlexAttentionMetadata: use_cascade: bool common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] + cu_prefix_query_lens: torch.Tensor | None + prefix_kv_lens: torch.Tensor | None + suffix_kv_lens: torch.Tensor | None # Block info total_cache_tokens: int @@ -300,15 +300,15 @@ class FlexAttentionMetadata: # Flex Metadata num_blocks = 0 - block_mask: Optional[BlockMask] = None - score_mod: Optional[_score_mod_signature] = None + block_mask: BlockMask | None = None + score_mod: _score_mod_signature | None = None logical_mask_mod: _mask_mod_signature = causal_mask_mod - doc_ids: Optional[torch.Tensor] = None + doc_ids: torch.Tensor | None = None direct_build: bool = True q_block_size: int = 16 kv_block_size: int = 16 - transformed_score_mod: Optional[_score_mod_signature] = None - sliding_window: Optional[int] = None + transformed_score_mod: _score_mod_signature | None = None + sliding_window: int | None = None def _convert_physical_to_logical( self, @@ -443,7 +443,7 @@ def get_mask_mod(self): mask_mod 
= and_masks(mask_mod, sliding_window_mask_mod) return mask_mod - def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: + def get_transformed_score_mod(self) -> _score_mod_signature | None: """Creates the transformed score_mod function for FlexAttention. This function wraps the user's score_mod to handle physical-to-logical @@ -658,7 +658,10 @@ def build( total_cache_tokens=total_cache_tokens, decode_offset=offset_tensor, num_blocks_per_seq=num_blocks_per_seq, - direct_build=self.direct_build, + # FIXME(Isotr0py): direct build has issue to build bidirectional + # attention block mask for encoder-only models, disable it temporarily. + # see: https://github.com/vllm-project/vllm/pull/27329#issuecomment-3431484053 + direct_build=(self.direct_build and common_attn_metadata.causal), q_block_size=self.q_block_size, kv_block_size=self.kv_block_size, ) @@ -669,9 +672,9 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class FlexAttentionImpl(AttentionImpl): - sliding_window: Optional[int] - alibi_slopes: Optional[torch.Tensor] - logits_soft_cap: Optional[float] + sliding_window: int | None + alibi_slopes: torch.Tensor | None + logits_soft_cap: float | None def __init__( self, @@ -679,12 +682,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, + kv_sharing_target_layer_name: str | None = None, **kwargs, ) -> None: self.num_heads = num_heads @@ -742,9 +745,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlexAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -768,7 +771,7 @@ def forward( if attn_metadata is None: # Profiling run. 
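Note (illustrative): FlexAttention expresses the causal rule and the optional sliding window as `mask_mod` predicates and compiles them into a `BlockMask`. A standalone sketch of the same pattern against PyTorch's public API (sizes are made up; vLLM additionally remaps physical cache slots to logical positions and pins BLOCK_M/BLOCK_N via `kernel_options` in batch-invariant mode, both skipped here):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Made-up sizes; vLLM derives these from the batch and its block tables.
B, H, Q_LEN, KV_LEN, D, WINDOW = 1, 4, 128, 128, 64, 32

def causal_sliding_window(b, h, q_idx, kv_idx):
    # Attend only to past tokens inside a fixed window -- the same kind of
    # predicate that causal_mask_mod plus the sliding-window mod express.
    return (q_idx >= kv_idx) & (q_idx - kv_idx < WINDOW)

block_mask = create_block_mask(causal_sliding_window, B, H, Q_LEN, KV_LEN, device="cpu")
q, k, v = (torch.randn(B, H, n, D) for n in (Q_LEN, KV_LEN, KV_LEN))
out = flex_attention(q, k, v, block_mask=block_mask)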
- return output + return output.fill_(0) # query = self.view_as_4d(query).permute(0, 2, 1, 3) # return torch.empty_like(query) @@ -860,11 +863,11 @@ def forward( def get_kernel_options( query, block_m, block_n, use_direct_build: bool -) -> dict[str, Union[int, bool]]: - kernel_options: dict[str, Union[int, bool]] = { +) -> dict[str, int | bool]: + kernel_options: dict[str, int | bool] = { "FORCE_USE_FLEX_ATTENTION": True, } - if vllm_kernel_override_batch_invariant(): + if vllm_is_batch_invariant(): kernel_options["BLOCK_M"] = 16 kernel_options["BLOCK_N"] = 16 kernel_options["IS_DIVISIBLE"] = False diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 21fc2ab72768..2ca19646911e 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -3,7 +3,6 @@ """Backend for GatedDeltaNet attention.""" from dataclasses import dataclass -from typing import Optional import torch @@ -36,29 +35,27 @@ class GDNAttentionMetadata: num_spec_decode_tokens: int num_actual_tokens: int - has_initial_state: Optional[torch.Tensor] = None + has_initial_state: torch.Tensor | None = None - spec_query_start_loc: Optional[torch.Tensor] = ( - None # shape: [num_spec_decodes + 1,] - ) - non_spec_query_start_loc: Optional[torch.Tensor] = ( + spec_query_start_loc: torch.Tensor | None = None # shape: [num_spec_decodes + 1,] + non_spec_query_start_loc: torch.Tensor | None = ( None # shape: [batch - num_spec_decodes + 1,] ) - spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec] - non_spec_state_indices_tensor: Optional[torch.Tensor] = ( + spec_state_indices_tensor: torch.Tensor | None = None # shape: [batch, num_spec] + non_spec_state_indices_tensor: torch.Tensor | None = ( None # shape: [batch - num_spec_decodes,] ) - spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,] - spec_token_masks: Optional[torch.Tensor] = ( - None # shape: [num_prefill_tokens + num_decode_tokens,] - ) - num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] + spec_sequence_masks: torch.Tensor | None = None # shape: [batch,] + spec_token_indx: torch.Tensor | None = None + non_spec_token_indx: torch.Tensor | None = None + + num_accepted_tokens: torch.Tensor | None = None # shape: [batch,] # The following attributes are for triton implementation of causal_conv1d - nums_dict: Optional[dict] = None - batch_ptr: Optional[torch.Tensor] = None - token_chunk_offset_ptr: Optional[torch.Tensor] = None + nums_dict: dict | None = None + batch_ptr: torch.Tensor | None = None + token_chunk_offset_ptr: torch.Tensor | None = None class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]): @@ -90,7 +87,7 @@ def __init__( ) self.decode_cudagraph_max_bs = min( self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1), - self.compilation_config.max_capture_size, + self.compilation_config.max_cudagraph_capture_size, ) self.spec_state_indices_tensor = torch.empty( @@ -108,9 +105,14 @@ def __init__( dtype=torch.bool, device=device, ) - self.spec_token_masks = torch.empty( + self.spec_token_indx = torch.empty( (self.decode_cudagraph_max_bs * (self.num_spec + 1),), - dtype=torch.bool, + dtype=torch.int32, + device=device, + ) + self.non_spec_token_indx = torch.empty( + (self.decode_cudagraph_max_bs * (self.num_spec + 1),), + dtype=torch.int32, device=device, ) self.spec_query_start_loc = torch.empty( @@ -133,8 +135,8 @@ def build( # type: ignore[override] self, common_prefix_len: int, 
common_attn_metadata: CommonAttentionMetadata, - num_accepted_tokens: Optional[torch.Tensor] = None, - num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None, + num_accepted_tokens: torch.Tensor | None = None, + num_decode_draft_tokens_cpu: torch.Tensor | None = None, fast_build: bool = False, ) -> GDNAttentionMetadata: m = common_attn_metadata @@ -169,7 +171,8 @@ def build( # type: ignore[override] split_decodes_and_prefills(m, decode_threshold=1) ) num_spec_decode_tokens = 0 - spec_token_masks = None + spec_token_indx = None + non_spec_token_indx = None spec_state_indices_tensor = None non_spec_state_indices_tensor = m.block_table_tensor[:, 0] spec_query_start_loc = None @@ -183,18 +186,23 @@ def build( # type: ignore[override] num_prefills = non_spec_query_lens.size(0) - num_decodes num_decode_tokens = num_decodes num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens + num_spec_decode_tokens = ( + query_lens.sum().item() - num_prefill_tokens - num_decode_tokens + ) if num_prefills == 0 and num_decodes == 0: - spec_token_masks = torch.ones( - ( - min( - num_spec_decodes * (self.num_spec + 1), - query_start_loc[-1].item(), - ) - ), - dtype=torch.bool, + spec_token_size = min( + num_spec_decodes * (self.num_spec + 1), + query_start_loc[-1].item(), + ) + spec_token_indx = torch.arange( + spec_token_size, + dtype=torch.int32, device=query_start_loc.device, ) + non_spec_token_indx = torch.empty( + 0, dtype=torch.int32, device=query_start_loc.device + ) spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1] non_spec_state_indices_tensor = None spec_query_start_loc = query_start_loc @@ -203,6 +211,11 @@ def build( # type: ignore[override] spec_token_masks = torch.repeat_interleave( spec_sequence_masks, query_lens ) + index = torch.argsort(spec_token_masks) + num_non_spec_tokens = num_prefill_tokens + num_decode_tokens + non_spec_token_indx = index[:num_non_spec_tokens] + spec_token_indx = index[num_non_spec_tokens:] + spec_state_indices_tensor = m.block_table_tensor[ spec_sequence_masks, : self.num_spec + 1 ] @@ -229,9 +242,6 @@ def build( # type: ignore[override] out=non_spec_query_start_loc[1:], ) - num_spec_decode_tokens = ( - query_lens.sum().item() - num_prefill_tokens - num_decode_tokens - ) assert num_accepted_tokens is not None num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] @@ -277,12 +287,18 @@ def build( # type: ignore[override] spec_sequence_masks = self.spec_sequence_masks[:batch_size] spec_sequence_masks[num_spec_decodes:].fill_(False) - assert spec_token_masks is not None - self.spec_token_masks[: spec_token_masks.size(0)].copy_( - spec_token_masks, non_blocking=True + assert non_spec_token_indx is not None and spec_token_indx is not None + self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_( + non_spec_token_indx, non_blocking=True + ) + non_spec_token_indx = self.non_spec_token_indx[ + : non_spec_token_indx.size(0) + ] + + self.spec_token_indx[: spec_token_indx.size(0)].copy_( + spec_token_indx, non_blocking=True ) - spec_token_masks = self.spec_token_masks[:num_actual_tokens] - spec_token_masks[spec_token_masks.size(0) :].fill_(False) + spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)] self.spec_query_start_loc[: num_spec_decodes + 1].copy_( spec_query_start_loc, non_blocking=True @@ -335,7 +351,8 @@ def build( # type: ignore[override] spec_state_indices_tensor=spec_state_indices_tensor, non_spec_state_indices_tensor=non_spec_state_indices_tensor, spec_sequence_masks=spec_sequence_masks, - 
spec_token_masks=spec_token_masks, + spec_token_indx=spec_token_indx, + non_spec_token_indx=non_spec_token_indx, num_accepted_tokens=num_accepted_tokens, nums_dict=nums_dict, batch_ptr=batch_ptr, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index e305cb2d8702..30c63e0ded8e 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import torch @@ -26,7 +25,7 @@ class Mamba1AttentionMetadata: query_start_loc: torch.Tensor context_lens_tensor: torch.Tensor state_indices_tensor: torch.Tensor - has_initial_states: Optional[torch.Tensor] + has_initial_states: torch.Tensor | None num_prefills: int num_prefill_tokens: int num_decodes: int diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 10f09442d82e..7ca8501a8a6f 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from dataclasses import dataclass -from typing import Optional import torch @@ -108,18 +107,18 @@ class Mamba2AttentionMetadata: # The following tensors only contain prefill requests and will be None if # the batch has no prefill request. - has_initial_states_p: Optional[torch.Tensor] - seq_idx_p: Optional[torch.Tensor] + has_initial_states_p: torch.Tensor | None + seq_idx_p: torch.Tensor | None # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for # each chunk, its offests into the varlen sequence dimension. It is defined # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to # cu_chunk_seqlen_p[i+1]. - cu_chunk_seqlen_p: Optional[torch.Tensor] + cu_chunk_seqlen_p: torch.Tensor | None # last_chunk_indices_p is a tensor of shape (batch,) that contains the # index of the last chunk for every sequence in the (prefill) batch. 
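Note (illustrative): the `spec_token_indx` / `non_spec_token_indx` change in the GDN builder above replaces boolean masking with an argsort-based partition of the flattened token stream. A small illustration with a made-up mask (the real mask comes from `repeat_interleave` over per-request query lengths; `stable=True` is passed here only to make the within-group order explicit):

import torch

# True marks tokens that belong to a speculative-decode request.
spec_token_masks = torch.tensor([False, False, True, True, False, True])
index = torch.argsort(spec_token_masks.to(torch.int), stable=True)
num_non_spec = int((~spec_token_masks).sum())
non_spec_token_indx = index[:num_non_spec]  # tensor([0, 1, 4])
spec_token_indx = index[num_non_spec:]      # tensor([2, 3, 5])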
- last_chunk_indices_p: Optional[torch.Tensor] + last_chunk_indices_p: torch.Tensor | None state_indices_tensor: torch.Tensor # shape: [batch,] block_idx_last_scheduled_token: torch.Tensor # shape: [batch,] @@ -128,9 +127,9 @@ class Mamba2AttentionMetadata: num_computed_tokens_p: torch.Tensor # shape: [batch,] # The following attributes are for triton implementation of causal_conv1d - nums_dict: Optional[dict] = None - batch_ptr: Optional[torch.Tensor] = None - token_chunk_offset_ptr: Optional[torch.Tensor] = None + nums_dict: dict | None = None + batch_ptr: torch.Tensor | None = None + token_chunk_offset_ptr: torch.Tensor | None = None class Mamba2AttentionMetadataBuilder( diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 5aafb9813df0..52f26a9e61ca 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -36,7 +36,7 @@ def __init__( self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = min( self.vllm_config.scheduler_config.max_num_seqs, - self.compilation_config.max_capture_size, + self.compilation_config.max_cudagraph_capture_size, ) self.state_indices_tensor = torch.empty( (self.decode_cudagraph_max_bs,), diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 3fb00f5917ea..b920fd929e85 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -190,7 +190,8 @@ import functools from abc import abstractmethod from dataclasses import dataclass, field -from typing import ClassVar, Generic, Optional, TypeVar, Union +from enum import Enum +from typing import ClassVar, Generic, TypeVar import torch from tqdm import tqdm @@ -210,6 +211,9 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.model_executor.layers.linear import ( ColumnParallelLinear, LinearBase, @@ -227,6 +231,24 @@ ) from vllm.v1.kv_cache_interface import AttentionSpec + +class QueryLenSupport(Enum): + """Defines the level of query length support for an attention backend's + decode pipeline. 
+ + - SINGLE_ONLY: Decode pipeline only supports single-token queries + (query_len=1) + - UNIFORM: Decode pipeline supports uniform multi-token queries + (all requests must have same query_len > 1) + - VARLEN: Decode pipeline supports variable-length queries + (mixed query lengths in same batch) + """ + + SINGLE_ONLY = "single_only" + UNIFORM = "uniform" + VARLEN = "varlen" + + try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -243,6 +265,8 @@ flashinfer_available = True except ImportError: + BatchPrefillWithRaggedKVCacheWrapper = object + flashinfer_available = False @@ -337,22 +361,23 @@ class ChunkedContextMetadata: workspace: torch.Tensor # for mla DCP - cp_chunk_seq_lens: Optional[list[list[int]]] = None - origin_context_lens: Optional[list[int]] = None - cp_cu_seq_lens: Optional[torch.Tensor] = None - chunk_size: Optional[int] = None - cu_seq_lens_lst: Optional[list[list[int]]] = None + cp_chunk_seq_lens: list[list[int]] | None = None + origin_context_lens: list[int] | None = None + cp_cu_seq_lens: torch.Tensor | None = None + chunk_size: int | None = None + cu_seq_lens_lst: list[list[int]] | None = None block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int - chunked_context: Optional[ChunkedContextMetadata] = None + chunked_context: ChunkedContextMetadata | None = None + query_seq_lens: torch.Tensor | None = None @dataclass class FlashInferPrefillMetadata(MLACommonPrefillMetadata): - prefill_main: Optional["BatchPrefillWithRaggedKVCacheWrapper"] = None - prefill_chunks: list["BatchPrefillWithRaggedKVCacheWrapper"] = field( + prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None + prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field( default_factory=list ) @@ -362,14 +387,14 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata): class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): seq_lens: torch.Tensor - query_seq_lens: Optional[torch.Tensor] = None - cudnn_workspace: Optional[torch.Tensor] = None + cudnn_workspace: torch.Tensor | None = None @dataclass class MLACommonDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor + dcp_tot_seq_lens: torch.Tensor | None D = TypeVar("D", bound=MLACommonDecodeMetadata) @@ -406,12 +431,15 @@ class MLACommonMetadata(Generic[D]): num_prefills: int # The dimension of the attention heads - head_dim: Optional[int] = None + head_dim: int | None = None - decode: Optional[D] = None - prefill: Optional[ - Union[MLACommonPrefillMetadata, FlashInferPrefillMetadata, CudnnPrefillMetadata] - ] = None + decode: D | None = None + prefill: ( + MLACommonPrefillMetadata + | FlashInferPrefillMetadata + | CudnnPrefillMetadata + | None + ) = None def __post_init__(self): if self.head_dim is not None: @@ -429,6 +457,7 @@ def use_flashinfer_prefill() -> bool: not envs.VLLM_DISABLE_FLASHINFER_PREFILL and flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL + and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL and current_platform.is_device_capability(100) ) @@ -442,6 +471,15 @@ def use_cudnn_prefill() -> bool: ) +def use_trtllm_ragged_deepseek_prefill() -> bool: + """Check if TRT-LLM ragged DeepSeek prefill should be used.""" + return ( + flashinfer_available + and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL + and current_platform.is_device_capability(100) + ) + + # Currently 394MB, this can be tuned based on GEMM sizes used. 
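Note (illustrative): taken together, the helpers above define the prefill dispatch that `MLACommonImpl.__init__` performs further down in this file. A sketch of that selection order, assuming this PR's helpers (the cuDNN gating is per the existing `use_cudnn_prefill`):

from vllm.v1.attention.backends.mla.common import (
    use_cudnn_prefill,
    use_flashinfer_prefill,
    use_trtllm_ragged_deepseek_prefill,
)

def select_mla_prefill_backend() -> str:
    if use_flashinfer_prefill():
        # SM100 with FlashInfer available and no cuDNN / TRT-LLM override.
        return "flashinfer"
    if use_trtllm_ragged_deepseek_prefill():
        # SM100 + VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=1
        return "trtllm_ragged_deepseek"
    if use_cudnn_prefill():
        return "cudnn"
    return "flash_attn_varlen"  # default fallback path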
# Chosen to be the same as sglang: # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37 @@ -454,19 +492,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): understand this class """ - # Whether the backend supports reordering the batch such that - # short sequences (i.e. verification for speculative decoding) are - # classified as decode requests. - # If True, this will increase `reorder_batch_threshold` (below) when - # speculative decoding is enabled, and set `require_uniform=True` when - # when reordering the batch. Non-uniform decode requests will - # fall back to prefill in this case. - supports_uniform_spec_as_decode: ClassVar[bool] = False + # Defines the level of query length support for this backend. + # - SINGLE_ONLY: Only single-token queries (no spec decode support) + # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths) + # - VARLEN: Supports variable-length queries (spec decode with mixed lengths) + # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when + # speculative decoding is enabled. + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY # The threshold for reordering the batch into decode and prefill requests. # If > 1, the batch will be reordered such that requests with # query length <= threshold are classified as decode requests. - # Use `supports_uniform_spec_as_decode` (above) to set this automatically + # Use `query_len_support` (above) to set this automatically # when speculative decoding is enabled. reorder_batch_threshold: int = 1 @@ -507,7 +544,7 @@ def __init__( layer_names: list[str], vllm_config: VllmConfig, device: torch.device, - metadata_cls: Optional[type[M]] = None, + metadata_cls: type[M] | None = None, ): self.metadata_cls = ( metadata_cls if metadata_cls is not None else MLACommonMetadata @@ -566,6 +603,7 @@ def __init__( self._use_cudnn_prefill = use_cudnn_prefill() self._use_fi_prefill = use_flashinfer_prefill() + self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill() self.prefill_metadata_cls = ( FlashInferPrefillMetadata if self._use_fi_prefill @@ -579,13 +617,18 @@ def __init__( FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device ) - self._fi_prefill_main: Optional[BatchPrefillWithRaggedKVCacheWrapper] = None + self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) ) + if self._use_trtllm_ragged_prefill: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + ) + if self._use_cudnn_prefill: self.cudnn_workspace = torch.empty( CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs, @@ -593,11 +636,18 @@ def __init__( device=device, ) - supports_spec_as_decode = self.supports_uniform_spec_as_decode + supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY self._init_reorder_batch_threshold( - self.reorder_batch_threshold, supports_spec_as_decode + self.reorder_batch_threshold, supports_spec_decode ) + # Validate consistency between query_len_support and reorder_batch_threshold + if self.query_len_support == QueryLenSupport.SINGLE_ONLY: + assert self.reorder_batch_threshold == 1, ( + f"reorder_batch_threshold must be 1 when query_len_support is " + f"SINGLE_ONLY, got 
{self.reorder_batch_threshold}" + ) + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): qo_indptr = prefill.query_start_loc @@ -682,10 +732,12 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) def build_for_cudagraph_capture( @@ -727,6 +779,7 @@ def build( query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -736,13 +789,16 @@ def build( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold, - require_uniform=self.supports_uniform_spec_as_decode, + require_uniform=(self.query_len_support != QueryLenSupport.VARLEN), ) ) # Note(hc): update seq_lens of decode reqs under DCP. if self.dcp_world_size > 1: - seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + ( + assert dcp_local_seq_lens is not None + dcp_local_seq_lens[:num_decodes] = seq_lens[ + :num_decodes + ] // self.dcp_world_size + ( self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size ) @@ -894,15 +950,25 @@ def build( ) prefill_metadata.cudnn_workspace = self.cudnn_workspace + if self._use_trtllm_ragged_prefill: + prefill_metadata.query_seq_lens = ( + prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] + ) + decode_metadata = None if num_decodes > 0: decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], seq_lens_cpu=seq_lens_cpu[:num_decodes], - seq_lens_device=seq_lens[:num_decodes], + seq_lens_device=dcp_local_seq_lens[:num_decodes] + if self.dcp_world_size > 1 and dcp_local_seq_lens is not None + else seq_lens[:num_decodes], query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], query_start_loc_device=query_start_loc[: num_decodes + 1], num_decode_tokens=num_decode_tokens, + dcp_tot_seq_lens_device=seq_lens[:num_decodes] + if self.dcp_world_size > 1 + else None, ) attn_metadata = self.metadata_cls( @@ -1011,14 +1077,14 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments - q_lora_rank: Optional[int], + q_lora_rank: int | None, kv_lora_rank: int, qk_nope_head_dim: int, qk_rope_head_dim: int, @@ -1026,7 +1092,7 @@ def __init__( v_head_dim: int, kv_b_proj: ColumnParallelLinear, indexer=None, - q_pad_num_heads: Optional[int] = None, + q_pad_num_heads: int | None = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -1145,6 +1211,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if is_rocm_aiter_fp8bmm_enabled(): # Multiply + Transpose (N, B, L) x (N, L, V)->(N, 
B, V)->(B, N, V) x = aiter_triton_fp8_bmm( @@ -1184,6 +1251,13 @@ def __init__(self, *args, **kwargs) -> None: self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi self._pad_v = False + elif use_trtllm_ragged_deepseek_prefill(): + logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA") + self._run_prefill_context_chunk = ( + self._run_prefill_context_chunk_trtllm_ragged + ) + self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged + self._pad_v = False elif use_cudnn_prefill(): logger.debug_once("Using CUDNN prefill for MLA") self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn @@ -1214,7 +1288,7 @@ def __init__(self, *args, **kwargs) -> None: and current_platform.get_device_capability()[0] == 9 ) - self.dcp_world_size: Optional[int] = None + self.dcp_world_size: int | None = None self.chunked_prefill_workspace_size = ( MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( @@ -1237,6 +1311,8 @@ def _flash_attn_varlen_diff_headdims( # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse + if vllm_is_batch_invariant(): + kwargs["num_splits"] = 1 attn_out = self.flash_attn_varlen_func( q=q, @@ -1278,6 +1354,7 @@ def _run_prefill_new_tokens_fi( ): assert isinstance(prefill, FlashInferPrefillMetadata) assert prefill.prefill_main is not None + ret = prefill.prefill_main.run( q=q, k=k, @@ -1286,7 +1363,6 @@ def _run_prefill_new_tokens_fi( ) if isinstance(ret, tuple): - # Convert from (q_len, num_heads) to (num_heads, q_len) return ret[0], ret[1].transpose(0, 1).contiguous() return ret @@ -1336,12 +1412,14 @@ def _run_prefill_context_chunk_fi( self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v ): assert isinstance(prefill, FlashInferPrefillMetadata) + attn_out, lse = prefill.prefill_chunks[chunk_idx].run( q=q, k=k, v=v, return_lse=True, ) + # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() @@ -1370,6 +1448,81 @@ def _run_prefill_context_chunk_cudnn( is_cuda_graph_compatible=True, ) + def _run_prefill_new_tokens_trtllm_ragged( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + """TRT-LLM ragged attention for new tokens (causal).""" + from flashinfer.prefill import trtllm_ragged_attention_deepseek + + assert prefill.query_seq_lens is not None + + ret = trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=self._workspace_buffer, + seq_lens=prefill.query_seq_lens, + max_q_len=prefill.max_query_len, + max_kv_len=prefill.max_query_len, + bmm1_scale=self.scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=prefill.query_seq_lens.shape[0], + window_left=-1, + cum_seq_lens_q=prefill.query_start_loc, + cum_seq_lens_kv=prefill.query_start_loc, + enable_pdl=False, + is_causal=True, + return_lse=return_softmax_lse, + ) + + if isinstance(ret, tuple): + # Convert from (q_len, num_heads) to (num_heads, q_len) + return ret[0], ret[1].transpose(0, 1).contiguous() + return ret + + def _run_prefill_context_chunk_trtllm_ragged( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + """TRT-LLM ragged attention for context chunks (non-causal).""" + from flashinfer.prefill import trtllm_ragged_attention_deepseek + + assert prefill.chunked_context is not None + assert prefill.chunked_context.seq_lens[chunk_idx] is not 
None + + out = torch.zeros( + q.shape[0], + q.shape[1], + v.shape[2], + device=q.device, + dtype=q.dtype, + ) + self._workspace_buffer.fill_(0) + + attn_out, lse = trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=self._workspace_buffer, + seq_lens=prefill.chunked_context.seq_lens[chunk_idx], + max_q_len=prefill.max_query_len, + max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx], + bmm1_scale=self.scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0], + window_left=-1, + cum_seq_lens_q=prefill.query_start_loc, + cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx], + enable_pdl=False, + is_causal=False, + return_lse=True, + out=out, + ) + + # Convert from (q_len, num_heads) to (num_heads, q_len) + return attn_out, lse.transpose(0, 1).contiguous() + def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") @@ -1698,11 +1851,11 @@ def _forward_prefill( @abstractmethod def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: raise NotImplementedError def forward( @@ -1713,9 +1866,9 @@ def forward( k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: M, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." 
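Note (illustrative): the decode hunk that follows relies on MLA's weight-absorption trick: instead of up-projecting the cached latent KV with `W_UK`, the projection is folded into the query via `W_UK_T`, so decode attention runs directly in the compressed `kv_lora_rank` space. A shape-level sketch with made-up sizes (rope dims, scaling, and batching omitted):

import torch

N, B, P, L = 16, 4, 128, 512   # heads, decode tokens, nope head dim, latent rank
T = 9                          # cached (compressed) KV entries for one request

q_nope = torch.randn(N, B, P)   # per-head queries, (N, B, P)
W_UK_T = torch.randn(N, P, L)   # transposed up-projection, (N, P, L)
kv_latent = torch.randn(T, L)   # compressed KV cache entries, (T, L)

# Absorb W_UK into the query: (N, B, P) @ (N, P, L) -> (N, B, L)
q_latent = torch.bmm(q_nope, W_UK_T)

# Scores against the latent cache: equivalent to first up-projecting the
# cache with W_UK and attending with q_nope, but much cheaper at decode.
scores = torch.einsum("nbl,tl->nbt", q_latent, kv_latent)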
@@ -1799,9 +1952,11 @@ def forward( if has_decode: assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) + # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) @@ -1826,17 +1981,18 @@ def forward( # Pads the head_dim if necessary (for the underlying kernel) N, B, P = decode_q_nope.shape _, _, L = self.W_UK_T.shape + if self.q_pad_num_heads is not None: decode_ql_nope = decode_q_nope.new_empty( (self.q_pad_num_heads, B, L) ) decode_ql_nope.resize_((N, B, L)) - else: decode_ql_nope = decode_q_nope.new_empty((N, B, L)) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) + # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index a3c677ca2108..c35e238eac4c 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import ClassVar, Optional, Union +from typing import ClassVar import torch @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, + MultipleOf, is_quantized_kv_cache, ) from vllm.logger import init_logger @@ -44,6 +45,10 @@ def get_impl_cls() -> type["CutlassMLAImpl"]: def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: return CutlassMLAMetadataBuilder + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [128] + class SM100Workspace: def __init__(self, initial_workspace_size): @@ -90,12 +95,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments **mla_args, ) -> None: @@ -134,7 +139,7 @@ def __init__( # FORCE_NUM_KV_SPLITS=1 force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None) if force_num_kv_splits: - logger.warning_once("Forcing num_kv_splits to %d", int(force_num_kv_splits)) + logger.debug_once("Forcing num_kv_splits to %d", int(force_num_kv_splits)) self._num_kv_splits = int(force_num_kv_splits) else: self._num_kv_splits = -1 # => Auto-detect @@ -227,11 +232,11 @@ def _sm100_cutlass_mla_decode( def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c0c2dbe1f961..a6aac701b784 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import ClassVar 
import torch @@ -17,14 +17,17 @@ get_flash_attn_version, ) from vllm.config import VllmConfig -from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, ) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -56,7 +59,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): query_start_loc: torch.Tensor max_query_len: int max_seq_len: int - scheduler_metadata: Optional[torch.Tensor] = None + scheduler_metadata: torch.Tensor | None = None max_num_splits: int = 0 @@ -67,8 +70,8 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - - reorder_batch_threshold: int = 512 + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN + reorder_batch_threshold: int = 512 # process small prefills with decode pathway def __init__( self, @@ -86,10 +89,9 @@ def __init__( self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) + self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size if self.use_full_cuda_graph and self.fa_aot_schedule: - self.max_cudagraph_size = self.compilation_config.max_capture_size - if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. @@ -107,21 +109,25 @@ def __init__( # pre-allocated during capture. self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH - # TODO(lucas): Until we add support for the DCP custom masking we need - # to restrict decodes to q_len == 1 when DCP is enabled. 
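Note (illustrative): the TODO above (and the q_len == 1 clamp removed just below) is what previously kept DCP decodes single-token; with varlen support each rank now attends only over its own interleaved shard of the KV cache, while the total lengths travel separately as `dcp_tot_seq_lens`. A worked example of the per-rank length formula from `MLACommonMetadataBuilder.build()`, with made-up sizes:

# local = seq_len // world_size + (rank <= (seq_len - 1) % world_size)
seq_len, world_size = 10, 4
local = [seq_len // world_size + (rank <= (seq_len - 1) % world_size)
         for rank in range(world_size)]
print(local)       # [3, 3, 2, 2]
print(sum(local))  # 10 -- every cached token lives on exactly one rank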
- self.reorder_batch_threshold = ( - 1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold - ) + if vllm_is_batch_invariant(): + self.max_num_splits = 1 def _schedule_decode( - self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + self, + num_reqs, + cu_query_lens, + max_query_len, + seqlens, + max_seq_len, + causal, + max_num_splits, ): if self.fa_aot_schedule: return get_scheduler_metadata( batch_size=num_reqs, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, - num_heads_q=self.num_heads, + num_heads_q=self.num_heads * self.dcp_world_size, num_heads_kv=1, headdim=self.mla_dims.qk_rope_head_dim, cache_seqlens=seqlens, @@ -130,7 +136,7 @@ def _schedule_decode( page_size=self.page_size, cu_seqlens_q=cu_query_lens, causal=causal, - num_splits=self.max_num_splits, + num_splits=max_num_splits, ) return None @@ -142,10 +148,20 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, ) -> FlashAttnMLADecodeMetadata: query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] max_query_len = query_lens_cpu.max().item() - max_seq_len = seq_lens_cpu.max().item() + max_seq_len = seq_lens_device.max().item() + + # For Flash Attention MLA + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. + max_num_splits = self.max_num_splits scheduler_metadata = self._schedule_decode( num_reqs=seq_lens_cpu.numel(), @@ -154,10 +170,9 @@ def _build_decode( seqlens=seq_lens_device, max_seq_len=max_seq_len, causal=True, + max_num_splits=max_num_splits, ) - # For FA3 + full cudagraph - max_num_splits = 0 if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] # Ensure the persistent buffer is large enough @@ -173,14 +188,10 @@ def _build_decode( self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - if num_decode_tokens <= self.max_cudagraph_size: - # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. 
- max_num_splits = self.max_num_splits + if vllm_is_batch_invariant(): + max_num_splits = 1 - return FlashAttnMLADecodeMetadata( + metadata = FlashAttnMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, query_start_loc=query_start_loc_device, @@ -188,7 +199,9 @@ def _build_decode( max_seq_len=max_seq_len, scheduler_metadata=scheduler_metadata, max_num_splits=max_num_splits, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) + return metadata class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): @@ -200,12 +213,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments **mla_args, ) -> None: @@ -247,11 +260,11 @@ def __init__( def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashAttnMLAMetadata, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -289,6 +302,9 @@ def _forward_decode( fa_version=3, # only version 3 is supported scheduler_metadata=attn_metadata.decode.scheduler_metadata, num_splits=attn_metadata.decode.max_num_splits, + cp_world_size=self.dcp_world_size, + cp_rank=self.dcp_rank, + cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens, ) if self.need_to_return_lse_for_decode: diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 206f96ea366a..44807c39cad3 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import ClassVar, Optional, Union +from typing import ClassVar import torch from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla @@ -13,6 +13,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, ) from vllm.v1.attention.backends.utils import AttentionCGSupport @@ -22,11 +23,8 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): - # enable spec-as-decode optimization - supports_uniform_spec_as_decode: ClassVar[bool] = True - - # enable full CUDA Graph support for decode-only capture cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM class FlashInferMLABackend(MLACommonBackend): @@ -57,12 +55,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments **mla_args, ) -> None: @@ -96,16 +94,16 @@ def __init__( ) self._workspace_buffer = g_fi_workspace - self.bmm1_scale: Optional[float] = None - 
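Note (illustrative): the NOTE being relocated in this hunk is the reason splits are only raised for CUDA-graph capture (and pinned to 1 in batch-invariant mode): each extra split costs an accumulator of shape [num_splits, num_heads, num_tokens, head_size]. A rough, purely illustrative size estimate -- all sizes and the fp32 assumption are made up:

num_splits, num_heads, num_tokens, head_size = 8, 16, 1024, 576
bytes_fp32 = 4
buf_bytes = num_splits * num_heads * num_tokens * head_size * bytes_fp32
print(f"{buf_bytes / 2**20:.0f} MiB")  # 288 MiB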
self.bmm2_scale: Optional[float] = None + self.bmm1_scale: float | None = None + self.bmm2_scale: float | None = None def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 6ba2c682760c..1f98204031ed 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -2,11 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import ClassVar import torch -from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf from vllm.attention.ops.flashmla import ( flash_mla_with_kvcache, get_mla_metadata, @@ -14,14 +14,22 @@ ) from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, + QueryLenSupport, +) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + reshape_attn_output_for_spec_decode, + reshape_query_for_spec_decode, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -44,6 +52,10 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]: def get_impl_cls() -> type["FlashMLAImpl"]: return FlashMLAImpl + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [64] + @dataclass class FlashMLADecodeMetadata(MLACommonDecodeMetadata): @@ -58,6 +70,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM + reorder_batch_threshold: int = 512 # process small prefills with decode pathway + # ^ TODO(matt): tune this def __init__( self, @@ -76,6 +91,7 @@ def __init__( self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None + self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") device_properties = torch.cuda.get_device_properties(self.device) num_sms = device_properties.multi_processor_count @@ -102,11 +118,17 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, ) -> FlashMLADecodeMetadata: + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + # we use the max but all should be the same due to uniform length requirement + max_query_len = query_lens_cpu.max().item() + num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1 tile_scheduler_metadata, num_splits = get_mla_metadata( seq_lens_device, - self.num_q_heads, + num_q_tokens_per_head_k, 1, # MQA for the decode 
path + is_fp8_kvcache=self.is_fp8_kvcache, ) # TODO: we can disambiguate between decode and mixed-prefill decode here @@ -142,6 +164,7 @@ def _build_decode( seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) @@ -154,12 +177,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments **mla_args, ) -> None: @@ -197,11 +220,11 @@ def __init__( def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: FlashMLAMetadata, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: (zyongye) decode function for mla here assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -209,19 +232,56 @@ def _forward_decode( if type(q) is tuple: q = torch.cat(q, dim=-1) + # mypy assertion: q is now always a tensor assert isinstance(q, torch.Tensor) + + num_decodes = attn_metadata.num_decodes + q = reshape_query_for_spec_decode(q, num_decodes) + + tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata + num_splits = attn_metadata.decode.num_splits + if vllm_is_batch_invariant(): + device = q.device + dtype = torch.int32 + + B = q.shape[0] + # block_table shape: [batch_size, max_num_blocks_per_seq] + # The number of blocks per sequence is in the second dimension + topk = attn_metadata.decode.block_table.shape[-1] + B_TOPK = 64 + assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}" + end_block_idx = topk // B_TOPK + + # Single partition => num_sm_parts = 1 + # TileSchedulerMetaDataSize = 8, layout: + # [begin_idx, begin_block_idx, end_idx, end_block_idx, + # begin_n_split_idx, _, _, _] + tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device) + tile_scheduler_metadata[0, 0] = 0 # begin_idx + tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx + tile_scheduler_metadata[0, 2] = B - 1 # end_idx + tile_scheduler_metadata[0, 3] = end_block_idx + tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx + # fields [5..7] stay 0 + + # Non-split path ignores num_splits, but the API requires it: + # zeros of length B+1 + num_splits = torch.zeros((B + 1,), dtype=dtype, device=device) + o, lse = flash_mla_with_kvcache( - q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) + q=q, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata, - num_splits=attn_metadata.decode.num_splits, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, softmax_scale=self.scale, causal=True, descale_q=layer._q_scale.reshape(1), descale_k=layer._k_scale.reshape(1), ) + o = reshape_attn_output_for_spec_decode(o) + return o, lse diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 49c29de35da1..141436e66c32 100644 --- 
a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -55,7 +55,7 @@ class FlashMLASparseBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "FLASHMLA_SPARSE_VLLM_V1" + return "FLASHMLA_SPARSE" @staticmethod def get_metadata_cls() -> type[AttentionMetadata]: @@ -110,12 +110,12 @@ class FlashMLASparseMetadata: @dataclass class FP8KernelMetadata: - scheduler_metadata: Optional[torch.Tensor] + scheduler_metadata: torch.Tensor | None num_splits: torch.Tensor dummy_block_table: torch.Tensor cache_lens: torch.Tensor - fp8_extra_metadata: Optional[FP8KernelMetadata] = None + fp8_extra_metadata: FP8KernelMetadata | None = None @triton.jit @@ -373,14 +373,14 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments - topk_indice_buffer: Optional[torch.Tensor] = None, + topk_indice_buffer: torch.Tensor | None = None, indexer: Optional["Indexer"] = None, **mla_args, ) -> None: @@ -466,9 +466,9 @@ def forward( k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: FlashMLASparseMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # MQA 576/512 approach for both prefill and decode diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 1344840af6a5..49009a939d0b 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar import torch -from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + MultipleOf, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata @@ -47,6 +51,10 @@ def get_kv_cache_shape( def get_kv_cache_stride_order() -> tuple[int, ...]: return (0, 1, 2) + @classmethod + def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]: + return [64] + @dataclass class DeepseekV32IndexerPrefillChunkMetadata: @@ -97,8 +105,8 @@ class DeepseekV32IndexerMetadata: num_prefills: int num_prefill_tokens: int - decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None - prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None + decode: DeepSeekV32IndexerDecodeMetadata | None = None + prefill: DeepseekV32IndexerPrefillMetadata | None = None # TODO (zyongye) optimize this, this is now vibe coded diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 54ebf071d96f..d935c02243bd 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ 
b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import ClassVar import torch @@ -47,14 +47,14 @@ def get_builder_cls() -> type["AiterMLAMetadataBuilder"]: @dataclass class AiterMLADecodeMetadata(MLACommonDecodeMetadata): # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None + paged_kv_indptr: torch.Tensor | None = None # The page indices of the paged kv cache - paged_kv_indices: Optional[torch.Tensor] = None + paged_kv_indices: torch.Tensor | None = None # The number of entries in the last page of each request in # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: Optional[torch.Tensor] = None + paged_kv_last_page_len: torch.Tensor | None = None # The query indptr, shape : [num_decode + 1] - qo_indptr: Optional[torch.Tensor] = None + qo_indptr: torch.Tensor | None = None class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): @@ -116,6 +116,7 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, + dcp_tot_seq_lens_device: torch.Tensor | None, ) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens_device + page_size - 1) // page_size @@ -174,6 +175,7 @@ def _build_decode( paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, qo_indptr=qo_indptr, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) return attn_metadata @@ -186,12 +188,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments **mla_args, ) -> None: @@ -240,11 +242,11 @@ def _flash_attn_varlen_diff_headdims( def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: AiterMLAMetadata, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 3b6718c48d09..781f77e96319 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union import torch @@ -14,6 +13,9 @@ from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import ( @@ -44,12 +46,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - 
sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float], + logits_soft_cap: float | None, attn_type: str, - kv_sharing_target_layer_name: Optional[str], + kv_sharing_target_layer_name: str | None, # MLA Specific Arguments **mla_args, ) -> None: @@ -138,11 +140,11 @@ def _flash_attn_varlen_diff_headdims( def _forward_decode( self, - q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -159,7 +161,9 @@ def _forward_decode( B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device ) lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) - num_kv_splits = 4 # TODO: heuristic + + # For batch invariance, use only 1 split to ensure deterministic reduction + num_kv_splits = 1 if vllm_is_batch_invariant() else 4 # TODO(lucas) Allocate ahead of time attn_logits = torch.empty( diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 1622f852a952..28085cb1424b 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import torch @@ -201,12 +200,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, + kv_sharing_target_layer_name: int | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -242,9 +241,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: PallasMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with Pallas attention. 
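# [Editorial aside — not part of the patch] The batch-invariant changes above
# (max_num_splits = 1 in the FlashAttn MLA builder, a single scheduler
# partition in FlashMLA, and num_kv_splits = 1 instead of 4 in Triton MLA)
# all serve the same purpose: with one split, partial attention results are
# combined in a fixed order, so the output does not depend on how the batch
# was partitioned. A minimal standalone sketch of the underlying effect,
# assuming nothing beyond floating-point non-associativity; split_k_sum is an
# illustrative helper, not vLLM code.
import torch

def split_k_sum(x: torch.Tensor, num_splits: int) -> torch.Tensor:
    # Reduce each split separately, then combine the partial sums, mimicking
    # how a split-KV kernel combines per-split outputs.
    partials = [chunk.sum() for chunk in x.chunk(num_splits)]
    return torch.stack(partials).sum()

x = torch.randn(1 << 16)
# With num_splits=1 the reduction order is fixed regardless of batch shape;
# with several splits the combine order (and the split boundaries) can change
# the last bits of the result, which is enough to break bitwise reproducibility.
print(split_k_sum(x, 1).item(), split_k_sum(x, 4).item(), split_k_sum(x, 8).item())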
@@ -342,7 +341,7 @@ def write_to_kv_cache( slot_mapping: torch.Tensor, num_slices_per_kv_cache_update_block: int, num_kv_update_slices: torch.Tensor, - kv_cache_quantized_dtype: Optional[torch.dtype] = None, + kv_cache_quantized_dtype: torch.dtype | None = None, k_scale: float = 1.0, v_scale: float = 1.0, ) -> None: diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 348eca55eefb..f7a4114a0a70 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -3,7 +3,6 @@ """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import Optional import torch @@ -12,6 +11,7 @@ AttentionImpl, AttentionMetadata, AttentionType, + MultipleOf, ) from vllm.config import VllmConfig from vllm.logger import init_logger @@ -29,7 +29,7 @@ import aiter from vllm.triton_utils import tl, triton - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op @triton.jit def _vllm_layout_trans_kernel( @@ -159,8 +159,8 @@ def flash_attn_varlen_func_impl( max_seqlen_q: int, max_seqlen_k: int, softmax_scale: float, - window_size: Optional[list[int]], # -1 means infinite context window - alibi_slopes: Optional[list[float]], + window_size: list[int] | None, # -1 means infinite context window + alibi_slopes: list[float] | None, block_table: torch.Tensor, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -208,8 +208,8 @@ def flash_attn_varlen_func_fake( max_seqlen_q: int, max_seqlen_k: int, softmax_scale: float, - window_size: Optional[list[int]], # -1 means infinite context window - alibi_slopes: Optional[list[float]], + window_size: list[int] | None, # -1 means infinite context window + alibi_slopes: list[float] | None, block_table: torch.Tensor, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -248,7 +248,7 @@ class AiterFlashAttentionMetadata: seq_lens: torch.Tensor slot_mapping: torch.Tensor block_table: torch.Tensor - cu_seq_lens: Optional[torch.Tensor] + cu_seq_lens: torch.Tensor | None # For cascade attention. use_cascade: bool @@ -282,7 +282,7 @@ def __init__( self.block_size = kv_cache_spec.block_size # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
- self.aot_sliding_window: Optional[tuple[int, int]] = None + self.aot_sliding_window: tuple[int, int] | None = None self.total_tokens: int = 0 def build_for_cudagraph_capture( @@ -359,6 +359,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [64, 128, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -407,12 +411,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, + kv_sharing_target_layer_name: int | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -453,9 +457,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AiterFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -481,7 +485,7 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index 235ea1c376ef..27b072106268 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" -from typing import Optional - import torch from vllm import _custom_ops as ops @@ -70,13 +68,13 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, ) -> None: super().__init__( num_heads, @@ -106,9 +104,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -132,7 +130,7 @@ def forward( if attn_metadata is None: # Profiling run. 
- return output + return output.fill_(0) assert attn_metadata.use_cascade is False diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 10dd01f0a5aa..8b7ce90a3cca 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -3,7 +3,7 @@ """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar import torch @@ -54,13 +54,13 @@ class RocmAttentionMetadata: # For cascade attention. use_cascade: bool common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] + cu_prefix_query_lens: torch.Tensor | None + prefix_kv_lens: torch.Tensor | None + suffix_kv_lens: torch.Tensor | None # Optional aot scheduling - scheduler_metadata: Optional[torch.Tensor] = None - prefix_scheduler_metadata: Optional[torch.Tensor] = None + scheduler_metadata: torch.Tensor | None = None + prefix_scheduler_metadata: torch.Tensor | None = None class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]): @@ -217,13 +217,13 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -273,9 +273,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -299,7 +299,7 @@ def forward( if attn_metadata is None: # Profiling run. 
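# [Editorial aside — not part of the patch] Several forward() methods in this
# section now return output.fill_(0) instead of output on the profiling run
# (attn_metadata is None). The output buffer handed to these impls is
# typically allocated uninitialized, so returning it untouched would let
# arbitrary memory contents flow through the dummy forward pass; zero-filling
# presumably avoids that. A tiny illustration, assuming only torch.empty
# semantics:
import torch

out = torch.empty(2, 3)   # uninitialized storage: values are arbitrary
safe = out.fill_(0)       # in-place zero fill, returns the same tensor
assert safe is out
assert bool((safe == 0).all())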
- return output + return output.fill_(0) assert attn_metadata.use_cascade is False diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index 74cfecca764e..22ad1054b35e 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import torch @@ -30,12 +29,12 @@ class ShortConvAttentionMetadata: query_start_loc: torch.Tensor state_indices_tensor: torch.Tensor - has_initial_states_p: Optional[torch.Tensor] + has_initial_states_p: torch.Tensor | None # For causal_conv1d - nums_dict: Optional[dict] = None - batch_ptr: Optional[torch.Tensor] = None - token_chunk_offset_ptr: Optional[torch.Tensor] = None + nums_dict: dict | None = None + batch_ptr: torch.Tensor | None = None + token_chunk_offset_ptr: torch.Tensor | None = None class ShortConvAttentionMetadataBuilder( diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index a209bb79580c..ee6ead9ad9b3 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -14,6 +14,7 @@ AttentionImpl, AttentionMetadata, AttentionType, + MultipleOf, ) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig @@ -39,6 +40,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -99,7 +104,7 @@ class TreeAttentionMetadata: num_prefills: int = 0 num_decodes: int = 0 - tree_attn_bias: Optional[torch.Tensor] = None + tree_attn_bias: torch.Tensor | None = None # Cached Prefill/decode metadata. _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None @@ -262,8 +267,8 @@ def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]: def _prepare_tree_attn_bias( sorted_tree_choices: list[tuple[int, ...]], depth_counts: list[int], - dtype: Optional[torch.dtype], - device: Optional[torch.device], + dtype: torch.dtype | None, + device: torch.device | None, ) -> torch.Tensor: # +1 comes from the additional root node. 
tree_len = len(sorted_tree_choices) + 1 @@ -305,12 +310,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, + kv_sharing_target_layer_name: str | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -349,9 +354,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TreeAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with TreeAttention. @@ -374,7 +379,7 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) # Cache the input KVs. key_cache, value_cache = kv_cache.unbind(0) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 9997ed16bed1..b1d34dbfd172 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -3,7 +3,7 @@ """High-Performance Triton-only Attention layer.""" from dataclasses import dataclass -from typing import ClassVar, Optional +from typing import ClassVar import torch @@ -12,6 +12,7 @@ AttentionImpl, AttentionMetadata, AttentionType, + MultipleOf, ) from vllm.attention.ops.triton_reshape_and_cache_flash import ( triton_reshape_and_cache_flash, @@ -31,11 +32,6 @@ ) from vllm.v1.kv_cache_interface import AttentionSpec -if current_platform.is_cuda_alike(): - from vllm import _custom_ops as ops -elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops as ops - logger = init_logger(__name__) @@ -60,13 +56,13 @@ class TritonAttentionMetadata: # For cascade attention. 
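# [Editorial aside — not part of the patch] Backends in this section now
# advertise kernel block-size support either as exact values (FlashMLA and the
# DeepSeek V3.2 indexer return [64]) or as a divisibility constraint
# ([MultipleOf(16)] for the AITER, tree, Triton and xFormers backends). The
# real MultipleOf lives in vllm.attention.backends.abstract and is not shown
# here, so the sketch below uses a stand-in dataclass; is_supported is a
# hypothetical helper showing how a caller might check a candidate block size.
from dataclasses import dataclass

@dataclass(frozen=True)
class MultipleOfSketch:
    base: int

def is_supported(block_size: int, supported: list) -> bool:
    for entry in supported:
        if isinstance(entry, MultipleOfSketch):
            if block_size % entry.base == 0:
                return True
        elif block_size == entry:
            return True
    return False

assert is_supported(32, [MultipleOfSketch(16)])   # 32 is a multiple of 16
assert not is_supported(48, [64])                 # only exactly 64 allowed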
use_cascade: bool common_prefix_len: int - cu_prefix_query_lens: Optional[torch.Tensor] - prefix_kv_lens: Optional[torch.Tensor] - suffix_kv_lens: Optional[torch.Tensor] + cu_prefix_query_lens: torch.Tensor | None + prefix_kv_lens: torch.Tensor | None + suffix_kv_lens: torch.Tensor | None # Optional aot scheduling - scheduler_metadata: Optional[torch.Tensor] = None - prefix_scheduler_metadata: Optional[torch.Tensor] = None + scheduler_metadata: torch.Tensor | None = None + prefix_scheduler_metadata: torch.Tensor | None = None class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): @@ -157,6 +153,10 @@ class TritonAttentionBackend(AttentionBackend): def get_supported_dtypes(cls) -> list[torch.dtype]: return [torch.float16, torch.bfloat16, torch.float32] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: # Triton Attention supports any head size above 32 @@ -205,19 +205,22 @@ class TritonAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym + def supports_quant_query_input(self) -> bool: + return current_platform.is_cuda() + def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -267,9 +270,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TritonAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with Paged Attention impl. in Triton. @@ -293,7 +296,7 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) assert attn_metadata.use_cascade is False @@ -333,19 +336,9 @@ def forward( if key_cache.dtype != self.fp8_dtype: key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) - num_tokens, num_heads, head_size = query.shape assert layer._q_scale_float == 1.0, ( "A non 1.0 q_scale is not currently supported." ) - if current_platform.is_cuda(): - # Skip Q quantization on ROCm and XPU, enable this on cuda - # only, since dequantizing back to f32 in the attention kernel - # is not supported. 
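# [Editorial aside — not part of the patch] The block removed just below used
# to FP8-quantize the query inside TritonAttentionImpl.forward on CUDA.
# Together with the new supports_quant_query_input() hook added earlier in
# this file's diff, the quantization presumably moves up to the calling
# attention layer when the impl reports support. A hypothetical caller-side
# sketch; maybe_quantize_query is invented, and the division is a simplified
# stand-in for ops.scaled_fp8_quant.
import torch

def maybe_quantize_query(impl, query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    if not impl.supports_quant_query_input():
        return query
    # Simplified static quantization: scale, then cast to FP8 storage.
    return (query / q_scale).to(torch.float8_e4m3fn)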
- query, _ = ops.scaled_fp8_quant( - query.reshape((num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale, - ) - query = query.reshape((num_tokens, num_heads, head_size)) cu_seqlens_q = attn_metadata.query_start_loc seqused_k = attn_metadata.seq_lens diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index add2c3cb8d59..cb5855548098 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,17 +4,15 @@ import enum import functools from abc import abstractmethod -from dataclasses import dataclass, fields, make_dataclass +from dataclasses import dataclass, field, fields, make_dataclass from typing import ( TYPE_CHECKING, Any, ClassVar, Generic, Literal, - Optional, Protocol, TypeVar, - Union, get_args, ) @@ -32,17 +30,17 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata -from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout, ) from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.ubatch_utils import UBatchSlice logger = init_logger(__name__) KVCacheLayoutType = Literal["NHD", "HND"] -_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None +_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None PAD_SLOT_ID = -1 @@ -87,11 +85,14 @@ class CommonAttentionMetadata: causal: bool = True # Needed by FastPrefillAttentionBuilder - logits_indices_padded: Optional[torch.Tensor] = None - num_logits_indices: Optional[int] = None + logits_indices_padded: torch.Tensor | None = None + num_logits_indices: int | None = None # Needed by CrossAttentionBuilder - encoder_seq_lens: Optional[np.ndarray] = None + encoder_seq_lens: np.ndarray | None = None + + dcp_local_seq_lens: torch.Tensor | None = None + """Sequence lengths of the local rank in decode context parallelism world""" def slice_query_start_locs( @@ -233,7 +234,7 @@ class AttentionCGSupport(enum.Enum): """Cudagraph always supported; supports mixed-prefill-decode""" UNIFORM_BATCH = 2 """Cudagraph supported for batches the only contain query lengths that are - the same, this can be used for spec-decode + the same, this can be used for spec-decode i.e. "decodes" are 1 + num_speculative_tokens""" UNIFORM_SINGLE_TOKEN_DECODE = 1 """Cudagraph supported for batches the only contain query_len==1 decodes""" @@ -247,7 +248,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. - reorder_batch_threshold: Optional[int] = None + reorder_batch_threshold: int | None = None @abstractmethod def __init__( @@ -344,6 +345,7 @@ def use_cascade_attention( use_sliding_window: bool, use_local_attention: bool, num_sms: int, + dcp_world_size: int, ) -> bool: return False @@ -392,9 +394,12 @@ class PerLayerParameters: """ window_left: int - logits_soft_cap: Optional[float] + logits_soft_cap: float | None sm_scale: float has_sinks: bool = False + # has same params for all layers + has_same_window_lefts: bool | None = field(default=None, compare=False) + has_same_all_params: bool | None = field(default=None, compare=False) def get_per_layer_parameters( @@ -405,7 +410,7 @@ def get_per_layer_parameters( to use during `plan`. 
""" - layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names) + layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names) per_layer_params: dict[str, PerLayerParameters] = {} for key, layer in layers.items(): @@ -446,20 +451,12 @@ def infer_global_hyperparameters( param_sets = list(per_layer_params.values()) global_params = param_sets[0] - # trtllm attention doesn't need global hyper params so disable the check - if not envs.VLLM_USE_TRTLLM_ATTENTION: - for params in param_sets: - if params.window_left != global_params.window_left: - raise ValueError( - "Window left is not the same for all layers. " - "One potential fix is to set disable_sliding_window=True" - ) - assert params == global_params, ( - "FlashInfer backend currently only supports models in which all" - "layers share the same values " - "for the following hyperparameters:" - "`window_left`, `logits_soft_cap`, `sm_scale`." - ) + global_params.has_same_window_lefts = all( + params.window_left == global_params.window_left for params in param_sets + ) + global_params.has_same_all_params = all( + params == global_params for params in param_sets + ) return global_params @@ -875,7 +872,7 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tens KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ - ("logits_indices_padded", Optional[torch.Tensor], None), + ("logits_indices_padded", torch.Tensor | None, None), ("num_logits_indices", int, 0), ] @@ -925,8 +922,8 @@ class KVSharingFastPrefillAttentionMetadata( ): def __init__(self, metadata, common_attn_metadata): # Shallow copy all fields in metadata cls - for field in fields(metadata.__class__): - setattr(self, field.name, getattr(metadata, field.name)) + for _field in fields(metadata.__class__): + setattr(self, _field.name, getattr(metadata, _field.name)) # Set additional fields that will be used in model code assert ( diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index b21562fac741..457b15ebdd82 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -12,6 +12,7 @@ AttentionImpl, AttentionMetadata, AttentionType, + MultipleOf, ) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig @@ -80,6 +81,10 @@ def get_supported_head_sizes(cls) -> list[int]: 256, ] + @staticmethod + def get_supported_kernel_block_size() -> list[int | MultipleOf]: + return [MultipleOf(16)] + @classmethod def validate_head_size(cls, head_size: int) -> None: supported_head_sizes = cls.get_supported_head_sizes() @@ -275,12 +280,12 @@ def __init__( head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, + kv_sharing_target_layer_name: str | None = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") @@ -323,9 +328,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: XFormersAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: 
torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with XFormers. @@ -349,7 +354,7 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) # Cache the input KVs. key_cache, value_cache = kv_cache.unbind(0) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ddfd94322737..15c06a0b107d 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable -from typing import Any, Optional, Union +from collections.abc import Iterable, Sequence +from typing import Any from vllm.distributed.kv_events import ( MEDIUM_GPU, @@ -51,10 +51,10 @@ class BlockHashToBlockMap: def __init__(self): self._cache: dict[ - BlockHashWithGroupId, Union[KVCacheBlock, dict[int, KVCacheBlock]] + BlockHashWithGroupId, KVCacheBlock | dict[int, KVCacheBlock] ] = {} - def get_one_block(self, key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: + def get_one_block(self, key: BlockHashWithGroupId) -> KVCacheBlock | None: """ Gets any block with the given block hash key. """ @@ -85,7 +85,7 @@ def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: else: self._unexpected_blocks_type(blocks) - def pop(self, key: BlockHashWithGroupId, block_id: int) -> Optional[KVCacheBlock]: + def pop(self, key: BlockHashWithGroupId, block_id: int) -> KVCacheBlock | None: """ Checks if block_hash exists and pop block_id from the cache """ @@ -168,7 +168,7 @@ def __init__( def get_cached_block( self, block_hash: BlockHash, kv_cache_group_ids: list[int] - ) -> Optional[list[KVCacheBlock]]: + ) -> list[KVCacheBlock] | None: """Get the cached block by the block hash for each group in `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. @@ -225,7 +225,7 @@ def cache_full_blocks( assert len(request.block_hashes) >= num_full_blocks new_block_hashes = request.block_hashes[num_cached_blocks:] - new_hashes: Optional[list[ExternalBlockHash]] = ( + new_hashes: list[ExternalBlockHash] | None = ( [] if self.enable_kv_cache_events else None ) for i, blk in enumerate(new_full_blocks): @@ -243,7 +243,7 @@ def cache_full_blocks( if self.enable_kv_cache_events: if num_cached_blocks == 0: - parent_block_hash: Optional[ExternalBlockHash] = None + parent_block_hash: ExternalBlockHash | None = None else: parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None @@ -328,7 +328,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: ) return True - def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: + def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index c70025992e70..3959e9a59a53 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -264,8 +264,8 @@ def compute_encoder_budget( from the input sequence. 
""" if mm_registry.supports_multimodal_inputs(model_config): - max_tokens_by_modality = ( - mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config) + max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality( + model_config ) return compute_mm_encoder_budget( diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index ef6da9adeea7..137e5e0cdb6d 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Optional +from collections.abc import Sequence from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock @@ -52,7 +52,7 @@ def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[list[KVCacheBlock], ...], + new_computed_blocks: tuple[Sequence[KVCacheBlock], ...], num_encoder_tokens: int, ) -> int: """ @@ -85,7 +85,7 @@ def get_num_blocks_to_allocate( return num_blocks_to_allocate def save_new_computed_blocks( - self, request_id: str, new_computed_blocks: tuple[list[KVCacheBlock], ...] + self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...] ) -> None: """ Add the new computed blocks to the request. @@ -320,8 +320,8 @@ def verify_and_split_kv_cache_groups(self) -> None: one of them is full attention. Then, split the kv cache groups into full attention groups and other groups. """ - full_attention_spec: Optional[FullAttentionSpec] = None - other_spec: Optional[KVCacheSpec] = None + full_attention_spec: FullAttentionSpec | None = None + other_spec: KVCacheSpec | None = None self.full_attention_group_ids: list[int] = [] self.other_group_ids: list[int] = [] for i, g in enumerate(self.kv_cache_config.kv_cache_groups): diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b74ccd30b97b..bb8cec91f36d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal, Optional, overload +from typing import Literal, overload from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -23,7 +25,7 @@ class KVCacheBlocks: structure from the Scheduler. """ - blocks: tuple[list[KVCacheBlock], ...] + blocks: tuple[Sequence[KVCacheBlock], ...] """ `blocks[i][j]` refers to the i-th kv_cache_group and the j-th block of tokens.We don't use block of @@ -31,12 +33,20 @@ class KVCacheBlocks: kv_cache_groups have the same number of blocks, which is true for now but will be broken if we want to give different block_size to different kv_cache_groups in the future. 
+ + Each single type KVCacheBlocks could be represented as: + - list[KVCacheBlock] for more than one KVCacheBlock + - an empty tuple for requests without KVCacheBlock + (a precomputed KVCacheBlocks is in KVCacheManager to avoid GC overhead) """ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" return KVCacheBlocks( - tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)) + tuple( + list(itertools.chain(blk1, blk2)) + for blk1, blk2 in zip(self.blocks, other.blocks) + ) ) @overload @@ -49,12 +59,12 @@ def get_block_ids( def get_block_ids( self, allow_none: Literal[True] = True, - ) -> Optional[tuple[list[int], ...]]: ... + ) -> tuple[list[int], ...] | None: ... def get_block_ids( self, allow_none: bool = False, - ) -> Optional[tuple[list[int], ...]]: + ) -> tuple[list[int], ...] | None: """ Converts the KVCacheBlocks instance to block_ids. @@ -74,8 +84,10 @@ def get_unhashed_block_ids(self) -> list[int]: return [block.block_id for block in self.blocks[0] if block.block_hash is None] def new_empty(self) -> "KVCacheBlocks": - """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks(tuple([] for _ in range(len(self.blocks)))) + """ + Creates a new KVCacheBlocks instance with no blocks. + """ + return KVCacheBlocks(tuple(() for _ in range(len(self.blocks)))) class KVCacheManager: @@ -97,7 +109,7 @@ def __init__( # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.block_size: Optional[int] = None + self.block_size: int | None = None if self.enable_caching: assert ( len( @@ -131,6 +143,15 @@ def __init__( self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config + # Pre-constructed KVCacheBlocks with no blocks, callers should use this + # via create_kv_cache_blocks instead of creating new ones to avoid GC + # overhead. + # + # We use nested tuples to ensure the empty KVCacheBlocks is immutable. + self.empty_kv_cache_blocks = KVCacheBlocks( + tuple(() for _ in range(self.num_kv_cache_groups)) + ) + @property def usage(self) -> float: """Get the KV cache usage. @@ -140,7 +161,7 @@ def usage(self) -> float: """ return self.block_pool.get_usage() - def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: + def make_prefix_cache_stats(self) -> PrefixCacheStats | None: """Get (and reset) the prefix cache stats. Returns: @@ -170,7 +191,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None ): - return self.create_empty_block_list(), 0 + return self.empty_kv_cache_blocks, 0 # NOTE: When all tokens hit the cache, we must recompute the last token # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. 
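# [Editorial aside — not part of the patch] The KVCacheManager changes above
# pre-build a single immutable empty KVCacheBlocks (a tuple of empty tuples)
# and reuse it wherever no blocks were allocated; create_kv_cache_blocks,
# shown later in this file's diff, falls back to that shared instance when
# every group is empty, avoiding per-request allocations and GC pressure.
# A minimal standalone sketch of the same reuse pattern; BlocksSketch and
# create_blocks are simplified stand-ins, not the real KVCacheBlocks API.
import itertools
from dataclasses import dataclass

@dataclass(frozen=True)
class BlocksSketch:
    # One sequence of blocks per KV cache group; () marks "no blocks" and is
    # safe to share because tuples are immutable.
    blocks: tuple

    def __add__(self, other: "BlocksSketch") -> "BlocksSketch":
        return BlocksSketch(
            tuple(
                list(itertools.chain(a, b))
                for a, b in zip(self.blocks, other.blocks)
            )
        )

NUM_GROUPS = 2
EMPTY = BlocksSketch(tuple(() for _ in range(NUM_GROUPS)))

def create_blocks(blocks: tuple) -> BlocksSketch:
    # Reuse the shared empty instance when nothing was allocated.
    return BlocksSketch(blocks) if any(blocks) else EMPTY

assert create_blocks(([], [])) is EMPTY
assert create_blocks((["blk0"], [])).blocks[0] == ["blk0"]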
@@ -187,29 +208,24 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: if self.log_stats: assert self.prefix_cache_stats is not None - if request.num_preemptions > 0: - # Previously preempted request - self.prefix_cache_stats.preempted_requests += 1 - self.prefix_cache_stats.preempted_queries += request.num_tokens - self.prefix_cache_stats.preempted_hits += num_new_computed_tokens - else: - # New request - self.prefix_cache_stats.requests += 1 - self.prefix_cache_stats.queries += request.num_tokens - self.prefix_cache_stats.hits += num_new_computed_tokens - - return KVCacheBlocks(computed_blocks), num_new_computed_tokens + self.prefix_cache_stats.record( + num_tokens=request.num_tokens, + num_hits=num_new_computed_tokens, + preempted=request.num_preemptions > 0, + ) + + return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens def allocate_slots( self, request: Request, num_new_tokens: int, num_new_computed_tokens: int = 0, - new_computed_blocks: Optional[KVCacheBlocks] = None, + new_computed_blocks: KVCacheBlocks | None = None, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, num_encoder_tokens: int = 0, - ) -> Optional[KVCacheBlocks]: + ) -> KVCacheBlocks | None: """Add slots for a request with new tokens to append. Args: @@ -251,9 +267,7 @@ def allocate_slots( if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks else: - new_computed_block_list = tuple( - [] for _ in range(len(self.kv_cache_config.kv_cache_groups)) - ) + new_computed_block_list = self.empty_kv_cache_blocks.blocks # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -305,7 +319,7 @@ def allocate_slots( # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. if not self.enable_caching or delay_cache_blocks: - return KVCacheBlocks(new_blocks) + return self.create_kv_cache_blocks(new_blocks) # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + # num_new_tokens, but must exclude "non-committable" tokens (e.g., @@ -316,7 +330,7 @@ def allocate_slots( ) self.coordinator.cache_blocks(request, num_tokens_to_cache) - return KVCacheBlocks(new_blocks) + return self.create_kv_cache_blocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -388,7 +402,7 @@ def take_events(self) -> list[KVCacheEvent]: def get_blocks(self, request_id: str) -> KVCacheBlocks: """Get the blocks of a request.""" - return KVCacheBlocks(self.coordinator.get_blocks(request_id)) + return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id)) def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" @@ -399,6 +413,8 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: if self.enable_caching: self.coordinator.cache_blocks(request, num_computed_tokens) - def create_empty_block_list(self) -> KVCacheBlocks: - """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks(tuple([] for _ in range(self.num_kv_cache_groups))) + def create_kv_cache_blocks( + self, blocks: tuple[list[KVCacheBlock], ...] 
+ ) -> KVCacheBlocks: + # Only create new KVCacheBlocks for non-empty blocks + return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bb0b7e259b41..584904daea8b 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -4,15 +4,17 @@ import copy import os -from collections import defaultdict, deque -from collections.abc import Iterable, Sequence +from collections import defaultdict +from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass -from typing import Any, Callable, NewType, Optional, Union +from typing import Any, NewType, TypeAlias from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import GiB_bytes, cdiv, sha256_cbor +from vllm.utils import cdiv +from vllm.utils.hashing import sha256_cbor +from vllm.utils.mem_constants import GiB_bytes from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, FullAttentionSpec, @@ -23,29 +25,29 @@ SlidingWindowSpec, UniformTypeKVCacheSpecs, ) -from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request +from vllm.v1.utils import tensor_data # BlockHash represents the hash of a single KV-cache block used for -# prefix caching. Treating it as a distinct type from ``bytes`` helps +# prefix caching. Treating it as a distinct type from `bytes` helps # catch accidental misuse when passing around raw byte strings. BlockHash = NewType("BlockHash", bytes) -# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID. +# `BlockHashWithGroupId` combines a `BlockHash` with its KV cache group ID. # It is represented as raw bytes for compactness and efficiency. The helper -# functions below pack/unpack the ``BlockHash`` and group id into/from the key. +# functions below pack/unpack the `BlockHash` and group id into/from the key. BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes) # ExternalBlockHash is used for reproducible prefix-cache block hashing. -# It's a union of ``bytes`` and ``int`` to keep backward compatibility +# It's a union of `bytes` and `int` to keep backward compatibility # after we default block hashing to use sha256 bytes. -ExternalBlockHash = Union[bytes, int] +ExternalBlockHash: TypeAlias = bytes | int def make_block_hash_with_group_id( block_hash: BlockHash, group_id: int ) -> BlockHashWithGroupId: - """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``. + """Pack a `BlockHash` and group id into a `BlockHashWithGroupId`. The group id is encoded using 4 bytes in big-endian order and appended to the block hash bytes. This representation avoids creating tuples while @@ -55,12 +57,12 @@ def make_block_hash_with_group_id( def get_block_hash(key: BlockHashWithGroupId) -> BlockHash: - """Extract the ``BlockHash`` from a ``BlockHashWithGroupId``.""" + """Extract the `BlockHash` from a `BlockHashWithGroupId`.""" return BlockHash(key[:-4]) def get_group_id(key: BlockHashWithGroupId) -> int: - """Extract the group id from a ``BlockHashWithGroupId``.""" + """Extract the group id from a `BlockHashWithGroupId`.""" return int.from_bytes(key[-4:], "big", signed=False) @@ -101,78 +103,6 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]): NONE_HASH = BlockHash(hash_fn(hash_seed)) -class PrefixCachingMetrics: - """Metrics for prefix caching with a hit rate of the max recent N requests. 
- - Args: - max_recent_requests: The number of the max recent requests to aggregate. - Defaults to 1000. - """ - - def __init__(self, max_recent_requests: int = 1000): - self.max_recent_requests = max_recent_requests - # The current aggregated values. - self.aggregated_requests = 0 - self.aggregated_query_total = 0 - self.aggregated_query_hit = 0 - # A deque of (requests, queries, hits) for the most recent requests. - self.query_queue: deque[tuple[int, int, int]] = deque() - - def observe(self, stats: PrefixCacheStats): - """Observe the prefix caching for a set of requests. - - This function is called with information gathered when new requests - are being scheduled and are looking for computed blocks. - - When there are more than `max_recent_requests` requests, the oldest set - of requests are removed from the metrics. - - Args: - stats: The prefix cache stats. - """ - # reset_prefix_cache was invoked before the current update. - # Reset the metrics before aggregating the current stats. - if stats.reset: - self.reset() - - # DO NOT appending empty stats to avoid helpful info get kicked out - # due to sliding window. - if stats.requests == 0: - return - - # Update the metrics. - self.query_queue.append((stats.requests, stats.queries, stats.hits)) - self.aggregated_requests += stats.requests - self.aggregated_query_total += stats.queries - self.aggregated_query_hit += stats.hits - - # Remove the oldest stats until number of requests does not exceed - # the limit. - # NOTE: We preserve the latest added stats regardless. - while ( - len(self.query_queue) > 1 - and self.aggregated_requests > self.max_recent_requests - ): - old_requests, old_queries, old_hits = self.query_queue.popleft() - self.aggregated_requests -= old_requests - self.aggregated_query_total -= old_queries - self.aggregated_query_hit -= old_hits - - def reset(self): - """Reset the metrics.""" - self.aggregated_requests = 0 - self.aggregated_query_total = 0 - self.aggregated_query_hit = 0 - self.query_queue.clear() - - @property - def hit_rate(self) -> float: - """Calculate the hit rate for the past N requests.""" - if self.aggregated_query_total == 0: - return 0.0 - return self.aggregated_query_hit / self.aggregated_query_total - - @dataclass class KVCacheBlock: """KV-cache block metadata.""" @@ -183,18 +113,18 @@ class KVCacheBlock: ref_cnt: int = 0 # The hash key (block hash + group id) of the block, only available # when the block is full and cached. - _block_hash: Optional[BlockHashWithGroupId] = None + _block_hash: BlockHashWithGroupId | None = None # Used to construct a doubly linked list for free blocks. # These two attributes should only be manipulated by FreeKVCacheBlockQueue. - prev_free_block: Optional["KVCacheBlock"] = None - next_free_block: Optional["KVCacheBlock"] = None + prev_free_block: "KVCacheBlock | None" = None + next_free_block: "KVCacheBlock | None" = None # Whether the block is a null block that should never be cached. is_null: bool = False @property - def block_hash(self) -> Optional[BlockHashWithGroupId]: + def block_hash(self) -> BlockHashWithGroupId | None: return self._block_hash @block_hash.setter @@ -444,7 +374,7 @@ def need_extra_keys(request: Request) -> bool: """ # Multimodal requests need to include the MM hash. - # LoRA requests need to include the LoRA ID. + # LoRA requests need to include the LoRA name. # Request with provided cache salt need to include the salt. 
return ( bool(request.mm_features) @@ -517,26 +447,48 @@ def _gen_mm_extra_hash_keys( return extra_keys, curr_mm_idx -def _gen_lora_extra_hash_keys(request: Request) -> list[int]: +def _gen_lora_extra_hash_keys(request: Request) -> list[str]: """Generate extra keys related to LoRA for block hash computation. Args: request: The request object. Returns: - Return LoRA id of the request if it is a LoRA request. Return empty + Return LoRA name of the request if it is a LoRA request. Return empty list otherwise. """ if not request.lora_request: return [] - return [request.lora_request.lora_int_id] + return [request.lora_request.lora_name] + + +def _gen_prompt_embeds_extra_hash_keys( + request: Request, start_token_idx: int, end_token_idx: int +) -> list[bytes]: + """Generate extra keys related to prompt embeds for block hash computation. + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + + Returns: + Return prompt embeddings data of the request if it has prompt embeds. + Return empty list otherwise. + """ + if request.prompt_embeds is None: + return [] + block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx] + embeds_bytes = tensor_data(block_prompt_embeds).tobytes() + return [embeds_bytes] def generate_block_hash_extra_keys( request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int -) -> tuple[Optional[tuple[Any, ...]], int]: +) -> tuple[tuple[Any, ...] | None, int]: """Generate extra keys for the block hash. The extra keys can come from - the multi-modal inputs and request specific metadata (e.g., LoRA ID). + the multi-modal inputs, request specific metadata (e.g., LoRA names), and + data from prompt embeddings. Args: request: The request object. @@ -551,12 +503,17 @@ def generate_block_hash_extra_keys( mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( request, start_token_idx, end_token_idx, start_mm_idx ) - lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) + lora_extra_keys: list[str] = _gen_lora_extra_hash_keys(request) cache_salt_keys: list[str] = ( [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else [] ) + prompt_embeds_keys = _gen_prompt_embeds_extra_hash_keys( + request, start_token_idx, end_token_idx + ) - extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys + extra_keys: list[Any] = ( + lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys + ) if not extra_keys: return None, new_start_mm_idx @@ -566,9 +523,9 @@ def generate_block_hash_extra_keys( def hash_block_tokens( hash_function: Callable[[Any], bytes], - parent_block_hash: Optional[BlockHash], + parent_block_hash: BlockHash | None, curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None, + extra_keys: tuple[Any, ...] | None = None, ) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). 
The hash value is used for @@ -1269,7 +1226,7 @@ def _report_kv_cache_config( vllm_config.parallel_config.decode_context_parallel_size, ) num_tokens_str = f"{num_tokens:,}" - logger.info("GPU KV cache size: %s tokens", num_tokens_str) + logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local") max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" max_concurrency = get_max_concurrency_for_kv_cache_config( vllm_config, kv_cache_config diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 968b4db530bf..da6e4aa2996b 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index b92ef395e9b7..c36483203343 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 @@ -80,7 +80,7 @@ def add_request(self, request: "Request") -> None: @abstractmethod def finish_requests( self, - request_ids: Union[str, Iterable[str]], + request_ids: str | Iterable[str], finished_status: "RequestStatus", ) -> None: """Finish the requests in the scheduler's internal queue. If the request diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index cbce91b990a1..035394f04530 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - from dataclasses import dataclass from typing import TYPE_CHECKING @@ -19,6 +17,13 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.request import Request +else: + KVConnectorMetadata = object + LoRARequest = object + MultiModalFeatureSpec = object + PoolingParams = object + SamplingParams = object + Request = object @bc_linter_include @@ -32,14 +37,14 @@ class NewRequestData: block_ids: tuple[list[int], ...] num_computed_tokens: int lora_request: LoRARequest | None - prompt_embeds: torch.Tensor | None = None + prompt_embeds: "torch.Tensor | None" = None @classmethod def from_request( cls, request: Request, block_ids: tuple[list[int], ...], - ) -> NewRequestData: + ) -> "NewRequestData": return cls( req_id=request.request_id, prompt_token_ids=request.prompt_token_ids, @@ -98,6 +103,9 @@ class CachedRequestData: # NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # When PP is not used, new_token_ids will be empty. new_token_ids: list[list[int]] + # If resumed_from_preemption is True, propagate the token ids to the + # connector; otherwise this will be empty. + resumed_req_token_ids: list[list[int] | None] new_block_ids: list[tuple[list[int], ...]
| None] num_computed_tokens: list[int] num_output_tokens: list[int] @@ -107,11 +115,12 @@ def num_reqs(self) -> int: return len(self.req_ids) @classmethod - def make_empty(cls) -> CachedRequestData: + def make_empty(cls) -> "CachedRequestData": return cls( req_ids=[], resumed_from_preemption=[], new_token_ids=[], + resumed_req_token_ids=[], new_block_ids=[], num_computed_tokens=[], num_output_tokens=[], @@ -156,11 +165,12 @@ class SchedulerOutput: # freed from the encoder cache. free_encoder_mm_hashes: list[str] - # Dict of request ids to their index within the batch - # for filling the next token bitmask - structured_output_request_ids: dict[str, int] + # ids of structured outputs requests included in the bitmask, in the + # same order as the corresponding stacked rows of the bitmask. + # There may be more than one row per request in the case of speculative decoding. + structured_output_request_ids: list[str] # the bitmask for the whole batch - grammar_bitmask: npt.NDArray[np.int32] | None + grammar_bitmask: "npt.NDArray[np.int32] | None" # KV Cache Connector metadata. kv_connector_metadata: KVConnectorMetadata | None = None diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 33e5ec72ebd7..7bc1010db23a 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import heapq from abc import ABC, abstractmethod from collections import deque @@ -43,7 +41,7 @@ def prepend_request(self, request: Request) -> None: pass @abstractmethod - def prepend_requests(self, requests: RequestQueue) -> None: + def prepend_requests(self, requests: "RequestQueue") -> None: """Prepend all requests from another queue to the front of this queue.""" pass diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f81750047ecc..7afee15a2da6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from __future__ import annotations - +import copy import itertools import time from collections import defaultdict from collections.abc import Iterable -from typing import Any, Union +from typing import TYPE_CHECKING, Any from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch @@ -15,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( KVConnectorBase_V1, KVConnectorRole, + supports_hma, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger @@ -30,12 +29,16 @@ from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + logger = init_logger(__name__) @@ -45,6 +48,7 @@ def __init__( vllm_config: VllmConfig, 
kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, + block_size: int, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, log_stats: bool = False, @@ -81,17 +85,19 @@ def __init__( # will have a corresponding KVConnector with Role=WORKER. # KV Connector pushes/pull of remote KVs for P/D and offloading. self.connector = None + self.connector_prefix_cache_stats: PrefixCacheStats | None = None if self.vllm_config.kv_transfer_config is not None: - assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "Multiple KV cache groups are not currently supported " - "with KV connectors" - ) assert not self.is_encoder_decoder, ( "Encoder-decoder models are not currently supported with KV connectors" ) + + connector_vllm_config = copy.copy(self.vllm_config) + connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) self.connector = KVConnectorFactory.create_connector( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER + config=connector_vllm_config, role=KVConnectorRole.SCHEDULER ) + if self.log_stats: + self.connector_prefix_cache_stats = PrefixCacheStats() self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -101,15 +107,8 @@ def __init__( num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 - self.block_size = self.cache_config.block_size - + self.block_size = block_size self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size - # Note(hc): The scheduler’s block_size must be multiplied - # by dcp_world_size, since block hashes are computed on the - # original full token sequence at a granularity of - # original_block_size × dcp_world_size. - if self.dcp_world_size > 1: - self.block_size *= self.dcp_world_size # req_id -> Request self.requests: dict[str, Request] = {} @@ -279,6 +278,10 @@ def schedule(self) -> SchedulerOutput: self.running.remove(preempted_req) if preempted_req in scheduled_running_reqs: scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[preempted_req.request_id] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop(preempted_req.request_id) + req_index -= 1 else: preempted_req = self.running.pop() @@ -426,9 +429,7 @@ def schedule(self) -> SchedulerOutput: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. else: - new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list() - ) + new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens @@ -528,6 +529,9 @@ def schedule(self) -> SchedulerOutput: new_computed_blocks + new_blocks, num_external_computed_tokens, ) + self._update_connector_prefix_cache_stats( + request, num_external_computed_tokens + ) # Request was already popped from self.waiting # unless it was re-added above due to new_blocks being None. 
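For context on the `_update_connector_prefix_cache_stats` call added above: each scheduled request records how many of its tokens were satisfied by the KV connector, so the scheduler can report an external-cache hit rate alongside the local prefix-cache stats. A minimal, self-contained sketch of that bookkeeping; the `ConnectorCacheStats` class below is an illustrative stand-in, not vLLM's actual `PrefixCacheStats`:

# Minimal sketch of per-connector prefix-cache accounting (stand-in class,
# not vLLM's PrefixCacheStats implementation).
from dataclasses import dataclass

@dataclass
class ConnectorCacheStats:
    requests: int = 0
    queries: int = 0  # tokens looked up against the external/connector cache
    hits: int = 0     # tokens actually served from the external/connector cache

    def record(self, num_tokens: int, num_hits: int) -> None:
        self.requests += 1
        self.queries += num_tokens
        self.hits += num_hits

    @property
    def hit_rate(self) -> float:
        return self.hits / self.queries if self.queries else 0.0

stats = ConnectorCacheStats()
stats.record(num_tokens=1024, num_hits=768)  # 768 of 1024 tokens loaded remotely
print(f"connector hit rate: {stats.hit_rate:.2%}")  # -> 75.00%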
@@ -615,11 +619,8 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens, req_to_new_blocks, ) - scheduled_requests = ( - scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs - ) structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( - scheduled_requests, scheduled_spec_decode_tokens + num_scheduled_tokens.keys(), scheduled_spec_decode_tokens ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, @@ -709,10 +710,15 @@ def _make_cached_request_data( req_ids: list[str] = [] new_token_ids: list[list[int]] = [] new_block_ids: list[tuple[list[int], ...] | None] = [] + resumed_req_token_ids: list[list[int] | None] = [] num_computed_tokens: list[int] = [] num_output_tokens: list[int] = [] - for req in itertools.chain(running_reqs, resumed_reqs): + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) + resumed_from_preemption += [True] * len(resumed_reqs) + for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): req_id = req.request_id req_ids.append(req_id) num_tokens = num_scheduled_tokens[req_id] - len( @@ -728,20 +734,25 @@ def _make_cached_request_data( req.num_computed_tokens : req.num_computed_tokens + num_tokens ] new_token_ids.append(token_ids) + resumed_token_ids = None + if resumed_from_preemption[idx]: + resumed_token_ids = req.all_token_ids[ + : req.num_computed_tokens + num_tokens + ] + resumed_req_token_ids.append(resumed_token_ids) new_block_ids.append( req_to_new_blocks[req_id].get_block_ids(allow_none=True) ) num_computed_tokens.append(req.num_computed_tokens) - num_output_tokens.append(req.num_output_tokens) - # Because resumed_reqs is usually empty, it is more efficient to do - # in-place appending so that we don't need to allocate a new list. - resumed_from_preemption = [False] * len(running_reqs) - resumed_from_preemption += [True] * len(resumed_reqs) + num_output_tokens.append( + req.num_output_tokens + req.num_output_placeholders + ) return CachedRequestData( req_ids=req_ids, resumed_from_preemption=resumed_from_preemption, new_token_ids=new_token_ids, + resumed_req_token_ids=resumed_req_token_ids, new_block_ids=new_block_ids, num_computed_tokens=num_computed_tokens, num_output_tokens=num_output_tokens, @@ -873,32 +884,28 @@ def _try_schedule_encoder_inputs( def get_grammar_bitmask( self, - requests: list[Request], + scheduled_request_ids: Iterable[str], scheduled_spec_decode_tokens: dict[str, list[int]], - ): - # NOTE: structured_output_request_ids maps - # a request's (request that uses structured output) - # request_id to its index in the batch. - # This will help us determine to slice the grammar bitmask - # and only applies valid mask for requests that - # uses structured decoding. - structured_output_request_ids: dict[str, int] = {} - for i, req in enumerate(requests): - if req.use_structured_output: - # PERF: in case of chunked prefill, - # request might not include any new tokens. - # Therefore, we might introduce some additional - # cycle to fill in the bitmask, which could be a big no-op. - structured_output_request_ids[req.request_id] = i - + ) -> tuple[list[str], "npt.NDArray[np.int32] | None"]: + # Collect list of scheduled request ids that use structured output. + # The corresponding rows of the bitmask will be in this order. + # PERF: in case of chunked prefill, + # request might not include any new tokens. 
+ # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. + structured_output_request_ids = [ + req_id + for req_id in scheduled_request_ids + if (req := self.requests.get(req_id)) and req.use_structured_output + ] if not structured_output_request_ids: - bitmask = None - else: - bitmask = self.structured_output_manager.grammar_bitmask( - self.requests, - structured_output_request_ids, - scheduled_spec_decode_tokens, - ) + return structured_output_request_ids, None + + bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) return structured_output_request_ids, bitmask def update_from_output( @@ -919,6 +926,10 @@ def update_from_output( kv_connector_stats = ( kv_connector_output.kv_connector_stats if kv_connector_output else None ) + if kv_connector_stats and self.connector: + stats = self.connector.get_kv_connector_stats() + if stats: + kv_connector_stats = kv_connector_stats.aggregate(stats) failed_kv_load_req_ids = None if kv_connector_output and kv_connector_output.invalid_block_ids: @@ -1006,12 +1017,10 @@ def update_from_output( new_logprobs = logprobs.slice(req_index, req_index + 1) if new_token_ids and self.structured_output_manager.should_advance(request): - # NOTE: structured_output_request - # should not be None if use_structured_output, we have - # checked above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids - ) + struct_output_request = request.structured_output_request + assert struct_output_request is not None + assert struct_output_request.grammar is not None + struct_output_request.grammar.accept_tokens(req_id, new_token_ids) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] @@ -1164,7 +1173,7 @@ def add_request(self, request: Request) -> None: def finish_requests( self, - request_ids: Union[str, Iterable[str]], + request_ids: str | Iterable[str], finished_status: RequestStatus, ) -> None: """Handles the finish signal from outside the scheduler. @@ -1185,7 +1194,7 @@ def finish_requests( # First pass: collect requests to remove from queues for req_id in request_ids: request = self.requests.get(req_id) - if request is None: + if request is None or request.is_finished(): # Invalid request ID. 
continue @@ -1244,11 +1253,13 @@ def make_stats( return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None + connector_prefix_cache_stats = self._make_connector_prefix_cache_stats() return SchedulerStats( num_running_reqs=len(self.running), num_waiting_reqs=len(self.waiting), kv_cache_usage=self.kv_cache_manager.usage, prefix_cache_stats=prefix_cache_stats, + connector_prefix_cache_stats=connector_prefix_cache_stats, spec_decoding_stats=spec_decoding_stats, num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, @@ -1279,6 +1290,25 @@ def shutdown(self) -> None: # KV Connector Related Methods ######################################################################## + def _update_connector_prefix_cache_stats( + self, request: Request, num_external_tokens: int + ) -> None: + if self.connector_prefix_cache_stats is None: + return + + self.connector_prefix_cache_stats.record( + num_tokens=request.num_tokens, + num_hits=num_external_tokens, + preempted=request.num_preemptions > 0, + ) + + def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None: + if self.connector_prefix_cache_stats is None: + return None + stats = self.connector_prefix_cache_stats + self.connector_prefix_cache_stats = PrefixCacheStats() + return stats + def get_kv_connector(self) -> KVConnectorBase_V1 | None: return self.connector @@ -1294,8 +1324,17 @@ def _connector_finished( if self.connector is None: return False, None - (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) - return self.connector.request_finished(request, block_ids) + block_ids = self.kv_cache_manager.get_block_ids(request.request_id) + + if not supports_hma(self.connector): + # NOTE(Kuntai): We should deprecate this code path after we enforce + # all connectors to support HMA. + # Hybrid memory allocator should be already turned off for this + # code path, but let's double-check here. + assert len(self.kv_cache_config.kv_cache_groups) == 1 + return self.connector.request_finished(request, block_ids[0]) + else: + return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: """ @@ -1363,14 +1402,8 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): self.finished_recving_kv_req_ids.add(req_id) for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) - if req_id not in self.requests: - logger.warning( - "Got finished sending KV transfer for request %s," - "but the request is already freed.", - req_id, - ) - else: - self._free_blocks(self.requests[req_id]) + assert req_id in self.requests + self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( self, requests: Iterable[Request], invalid_block_ids: set[int] @@ -1466,7 +1499,7 @@ def _update_requests_with_invalid_blocks( affected_req_ids.add(request.request_id) - return (affected_req_ids, total_affected_tokens) + return affected_req_ids, total_affected_tokens def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: total_requests_to_reschedule = 0 @@ -1488,7 +1521,7 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: total_tokens_to_reschedule += num_tokens_to_reschedule # Mark requests with async KV load failures; they will be rescheduled - # once loading completes + # once loading completes. 
self.failed_recving_kv_req_ids |= async_affected_req_ids # --- Handle sync KV loads (running requests) --- diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 4f17468d2d58..8af8a7d27806 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib -from typing import Optional import torch @@ -41,7 +40,7 @@ def remove_all(lst: list, items_to_remove: set) -> list: def check_stop( - request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None + request: Request, max_model_len: int, pooler_output: torch.Tensor | None = None ) -> bool: if ( request.num_tokens >= max_model_len @@ -58,6 +57,10 @@ def check_stop( sampling_params = request.sampling_params assert sampling_params is not None + + if request.num_output_tokens < sampling_params.min_tokens: + return False + last_token_id = request.output_token_ids[-1] if not sampling_params.ignore_eos and last_token_id == request.eos_token_id: request.status = RequestStatus.FINISHED_STOPPED diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 7984a6ce29df..586034182686 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -3,6 +3,7 @@ import itertools from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Sequence from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool @@ -61,7 +62,10 @@ def __init__( self._null_block = block_pool.null_block def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + self, + request_id: str, + num_tokens: int, + new_computed_blocks: Sequence[KVCacheBlock], ) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -93,7 +97,7 @@ def get_num_blocks_to_allocate( return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( - self, request_id: str, new_computed_blocks: list[KVCacheBlock] + self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] ) -> None: """ Add the new computed blocks to the request. @@ -593,7 +597,10 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return 0 def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + self, + request_id: str, + num_tokens: int, + new_computed_blocks: Sequence[KVCacheBlock], ) -> int: # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. @@ -625,7 +632,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager): """Manager for cross-attention KV cache in encoder-decoder models.""" def save_new_computed_blocks( - self, request_id: str, new_computed_blocks: list[KVCacheBlock] + self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock] ) -> None: # We do not cache blocks for cross-attention to be shared between # requests, so `new_computed_blocks` should always be empty. 
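The new `min_tokens` guard in `check_stop` above short-circuits the EOS and stop-token checks until the request has produced at least `sampling_params.min_tokens` output tokens, while the length-based checks earlier in the function still apply. A simplified standalone sketch of that ordering, using plain stand-in objects rather than vLLM's `Request`/`SamplingParams`:

# Simplified stop check illustrating the min_tokens guard (stand-in types,
# not vLLM's actual Request/SamplingParams).
from dataclasses import dataclass, field

@dataclass
class Params:
    min_tokens: int = 0
    ignore_eos: bool = False
    stop_token_ids: set[int] = field(default_factory=set)

@dataclass
class Req:
    output_token_ids: list[int]
    eos_token_id: int
    params: Params

def should_stop(req: Req) -> bool:
    # Never stop on EOS/stop tokens before min_tokens outputs exist.
    if len(req.output_token_ids) < req.params.min_tokens:
        return False
    last = req.output_token_ids[-1]
    if not req.params.ignore_eos and last == req.eos_token_id:
        return True
    return last in req.params.stop_token_ids

req = Req(output_token_ids=[11, 2], eos_token_id=2, params=Params(min_tokens=4))
print(should_stop(req))  # False: EOS is ignored until 4 output tokens exist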
diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index ce4714702869..b480ac78f23c 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from itertools import product from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor @@ -44,12 +44,12 @@ def __init__(self, vllm_config: VllmConfig): not_use_piecewise_compilation or self.compilation_config.is_attention_compiled_piecewise() ), ( - "Compilation level should be CompilationLevel.PIECEWISE when " + "Compilation mode should be CompilationMode.VLLM_COMPILE when " "cudagraph_mode piecewise cudagraphs is used, " "and attention should be in splitting_ops or " "inductor splitting should be used. " f"cudagraph_mode={self.cudagraph_mode}, " - f"compilation_level={self.compilation_config.level}, " + f"compilation_mode={self.compilation_config.mode}, " f"splitting_ops={self.compilation_config.splitting_ops}" ) @@ -68,14 +68,27 @@ def initialize_cudagraph_keys( ): # This should be called only after attention backend is initialized. + # LoRA activation cases to specialize the cuda graphs on + if self.vllm_config.lora_config: + if self.compilation_config.cudagraph_specialize_lora: + lora_cases = [True, False] + else: + lora_cases = [True] + else: + lora_cases = [False] + # Note: we create all valid keys for cudagraph here but do not # guarantee all keys would be used. For example, if we allow lazy # capturing in future PR, some keys may never be triggered. if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - for bs in self.compilation_config.cudagraph_capture_sizes: + for bs, has_lora in product( + self.compilation_config.cudagraph_capture_sizes, lora_cases + ): self.add_cudagraph_key( cudagraph_mode.mixed_mode(), - BatchDescriptor(num_tokens=bs, uniform_decode=False), + BatchDescriptor( + num_tokens=bs, uniform_decode=False, has_lora=has_lora + ), ) # if decode cudagraph mode is FULL, and we don't already have mixed @@ -93,16 +106,18 @@ def initialize_cudagraph_keys( for x in self.compilation_config.cudagraph_capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] - for bs in cudagraph_capture_sizes_for_decode: + for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): self.add_cudagraph_key( CUDAGraphMode.FULL, - BatchDescriptor(num_tokens=bs, uniform_decode=True), + BatchDescriptor( + num_tokens=bs, uniform_decode=True, has_lora=has_lora + ), ) self.keys_initialized = True def dispatch( self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False - ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]: + ) -> tuple[CUDAGraphMode, BatchDescriptor | None]: """ Given conditions(e.g.,batch descriptor and if using cascade attention), dispatch to a cudagraph runtime mode and the valid batch descriptor. 
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 163c050e559e..e2c1ed7b561c 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -4,7 +4,7 @@ import enum import time from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any import msgspec import torch @@ -48,16 +48,16 @@ class EngineCoreRequest( gc=False, ): # type: ignore[call-arg] request_id: str - prompt_token_ids: Optional[list[int]] - mm_features: Optional[list[MultiModalFeatureSpec]] - sampling_params: Optional[SamplingParams] - pooling_params: Optional[PoolingParams] - eos_token_id: Optional[int] + prompt_token_ids: list[int] | None + mm_features: list[MultiModalFeatureSpec] | None + sampling_params: SamplingParams | None + pooling_params: PoolingParams | None + eos_token_id: int | None arrival_time: float - lora_request: Optional[LoRARequest] - cache_salt: Optional[str] - data_parallel_rank: Optional[int] - prompt_embeds: Optional[torch.Tensor] = None + lora_request: LoRARequest | None + cache_salt: str | None + data_parallel_rank: int | None + prompt_embeds: torch.Tensor | None = None # Index of the client, used to ensure outputs are sent back to the same # client for this request when scaling out the front-end. @@ -69,7 +69,7 @@ class EngineCoreRequest( current_wave: int = 0 priority: int = 0 - trace_headers: Optional[Mapping[str, str]] = None + trace_headers: Mapping[str, str] | None = None class EngineCoreEventType(enum.IntEnum): @@ -93,7 +93,7 @@ class EngineCoreEvent(msgspec.Struct): @classmethod def new_event( - cls, event_type: EngineCoreEventType, timestamp: Optional[float] = None + cls, event_type: EngineCoreEventType, timestamp: float | None = None ) -> "EngineCoreEvent": timestamp = time.monotonic() if timestamp is None else timestamp return cls(event_type, timestamp) @@ -108,17 +108,17 @@ class EngineCoreOutput( request_id: str new_token_ids: list[int] - new_logprobs: Optional[LogprobsLists] = None - new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None + new_logprobs: LogprobsLists | None = None + new_prompt_logprobs_tensors: LogprobsTensors | None = None - pooling_output: Optional[torch.Tensor] = None + pooling_output: torch.Tensor | None = None - finish_reason: Optional[FinishReason] = None - stop_reason: Union[int, str, None] = None - events: Optional[list[EngineCoreEvent]] = None - kv_transfer_params: Optional[dict[str, Any]] = None + finish_reason: FinishReason | None = None + stop_reason: int | str | None = None + events: list[EngineCoreEvent] | None = None + kv_transfer_params: dict[str, Any] | None = None - trace_headers: Optional[Mapping[str, str]] = None + trace_headers: Mapping[str, str] | None = None # The number of tokens with prefix cache hits. num_cached_tokens: int = 0 @@ -142,8 +142,8 @@ class UtilityOutput( call_id: int # Non-None implies the call failed, result should be None. 
- failure_message: Optional[str] = None - result: Optional[UtilityResult] = None + failure_message: str | None = None + result: UtilityResult | None = None class EngineCoreOutputs( @@ -159,18 +159,18 @@ class EngineCoreOutputs( # [num_reqs] outputs: list[EngineCoreOutput] = [] - scheduler_stats: Optional[SchedulerStats] = None + scheduler_stats: SchedulerStats | None = None timestamp: float = 0.0 - utility_output: Optional[UtilityOutput] = None - finished_requests: Optional[set[str]] = None + utility_output: UtilityOutput | None = None + finished_requests: set[str] | None = None # In DP case, used to signal that the current wave of requests # has finished and the engines are paused. - wave_complete: Optional[int] = None + wave_complete: int | None = None # In DP case, used to signal that a request was received for an # "old" wave, so the next wave needs to be started in other engines. - start_wave: Optional[int] = None + start_wave: int | None = None def __post_init__(self): if self.timestamp == 0.0: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 5be1f833e3f6..62faf590b23f 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -6,39 +6,45 @@ import time from collections.abc import AsyncGenerator, Iterable, Mapping from copy import copy -from typing import Any, Optional, Union +from typing import Any import numpy as np import torch import vllm.envs as envs -from vllm.config import ModelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient from vllm.entrypoints.utils import _validate_truncation_size -from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE from vllm.inputs import PromptType -from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs +from vllm.utils import Device, cdiv +from vllm.utils.async_utils import cancel_task_threadsafe +from vllm.utils.collection_utils import as_list +from vllm.utils.func_utils import deprecate_kwargs from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor -from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager +from vllm.v1.executor import Executor +from vllm.v1.metrics.loggers import ( + StatLoggerFactory, + StatLoggerManager, + load_stat_logger_plugin_factories, +) from vllm.v1.metrics.prometheus import shutdown_prometheus from vllm.v1.metrics.stats import 
IterationStats @@ -56,8 +62,9 @@ def __init__( use_cached_outputs: bool = False, log_requests: bool = True, start_engine_loop: bool = True, - stat_loggers: Optional[list[StatLoggerFactory]] = None, - client_addresses: Optional[dict[str, str]] = None, + stat_loggers: list[StatLoggerFactory] | None = None, + aggregate_engine_logging: bool = False, + client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, ) -> None: @@ -97,15 +104,28 @@ def __init__( self.observability_config = vllm_config.observability_config self.log_requests = log_requests - self.log_stats = log_stats or (stat_loggers is not None) - if not log_stats and stat_loggers is not None: + custom_stat_loggers = list(stat_loggers or []) + custom_stat_loggers.extend(load_stat_logger_plugin_factories()) + + has_custom_loggers = bool(custom_stat_loggers) + self.log_stats = log_stats or has_custom_loggers + if not log_stats and has_custom_loggers: logger.info( - "AsyncLLM created with log_stats=False and non-empty custom " - "logger list; enabling logging without default stat loggers" + "AsyncLLM created with log_stats=False, " + "but custom stat loggers were found; " + "enabling logging without default stat loggers." ) - # Processor (converts Inputs --> EngineCoreRequests). - self.processor = Processor(vllm_config, mm_registry=mm_registry) + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = init_tokenizer_from_configs(self.model_config) + + self.processor = Processor(self.vllm_config, tokenizer) + self.io_processor = get_io_processor( + self.vllm_config, + self.model_config.io_processor_plugin, + ) # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). self.output_processor = OutputProcessor( @@ -128,18 +148,19 @@ def __init__( ) # Loggers. - self.logger_manager: Optional[StatLoggerManager] = None + self.logger_manager: StatLoggerManager | None = None if self.log_stats: self.logger_manager = StatLoggerManager( vllm_config=vllm_config, engine_idxs=self.engine_core.engine_ranks_managed, - custom_stat_loggers=stat_loggers, + custom_stat_loggers=custom_stat_loggers, enable_default_loggers=log_stats, client_count=client_count, + aggregate_engine_logging=aggregate_engine_logging, ) self.logger_manager.log_engine_initialized() - self.output_handler: Optional[asyncio.Task] = None + self.output_handler: asyncio.Task | None = None try: # Start output handler eagerly if we are in the asyncio eventloop. 
asyncio.get_running_loop() @@ -177,10 +198,11 @@ def from_vllm_config( vllm_config: VllmConfig, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_loggers: list[StatLoggerFactory] | None = None, enable_log_requests: bool = False, + aggregate_engine_logging: bool = False, disable_log_stats: bool = False, - client_addresses: Optional[dict[str, str]] = None, + client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, disable_log_requests: bool = True, # Deprecated, will be removed @@ -201,6 +223,7 @@ def from_vllm_config( stat_loggers=stat_loggers, log_requests=enable_log_requests, log_stats=not disable_log_stats, + aggregate_engine_logging=aggregate_engine_logging, usage_context=usage_context, client_addresses=client_addresses, client_count=client_count, @@ -213,7 +236,7 @@ def from_engine_args( engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_loggers: list[StatLoggerFactory] | None = None, ) -> "AsyncLLM": """Create an AsyncLLM from the EngineArgs.""" @@ -245,25 +268,21 @@ def shutdown(self): cancel_task_threadsafe(getattr(self, "output_handler", None)) - @property - def tokenizer(self) -> Optional[AnyTokenizer]: - return self.processor.tokenizer - async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return await self.engine_core.get_supported_tasks_async() async def add_request( self, request_id: str, - prompt: Union[EngineCoreRequest, PromptType], - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - trace_headers: Optional[Mapping[str, str]] = None, + prompt: EngineCoreRequest | PromptType, + params: SamplingParams | PoolingParams, + arrival_time: float | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - data_parallel_rank: Optional[int] = None, - prompt_text: Optional[str] = None, + data_parallel_rank: int | None = None, + prompt_text: str | None = None, ) -> RequestOutputCollector: """Add new request to the AsyncLLM.""" @@ -321,8 +340,8 @@ async def add_request( async def _add_request( self, request: EngineCoreRequest, - prompt: Optional[str], - parent_req: Optional[ParentRequest], + prompt: str | None, + parent_req: ParentRequest | None, index: int, queue: RequestOutputCollector, ): @@ -342,16 +361,16 @@ async def _add_request( # re-multiplexed in the API server anyhow. 
async def generate( self, - prompt: Union[EngineCoreRequest, PromptType], + prompt: EngineCoreRequest | PromptType, sampling_params: SamplingParams, request_id: str, *, - prompt_text: Optional[str] = None, - lora_request: Optional[LoRARequest] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - trace_headers: Optional[Mapping[str, str]] = None, + prompt_text: str | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - data_parallel_rank: Optional[int] = None, + data_parallel_rank: int | None = None, ) -> AsyncGenerator[RequestOutput, None]: """ Main function called by the API server to kick off a request @@ -459,6 +478,7 @@ def _run_output_handler(self): output_processor = self.output_processor log_stats = self.log_stats logger_manager = self.logger_manager + processor = self.processor async def output_handler(): try: @@ -474,12 +494,12 @@ async def output_handler(): # Split outputs into chunks of at most # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the # event loop for too long. - if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: + if num_outputs <= envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: slices = (outputs.outputs,) else: slices = np.array_split( outputs.outputs, - cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE), + cdiv(num_outputs, envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE), ) for i, outputs_slice in enumerate(slices): @@ -507,6 +527,7 @@ async def output_handler(): engine_idx=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, + mm_cache_stats=processor.stat_mm_cache(), ) except Exception as e: logger.exception("AsyncLLM output_handler failed.") @@ -514,7 +535,7 @@ async def output_handler(): self.output_handler = asyncio.create_task(output_handler()) - async def abort(self, request_id: Union[str, Iterable[str]]) -> None: + async def abort(self, request_id: str | Iterable[str]) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" request_ids = ( @@ -531,11 +552,11 @@ async def encode( prompt: PromptType, pooling_params: PoolingParams, request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, + lora_request: LoRARequest | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - truncate_prompt_tokens: Optional[int] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, + truncate_prompt_tokens: int | None = None, + tokenization_kwargs: dict[str, Any] | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """ Main function called by the API server to kick off a request @@ -615,14 +636,13 @@ async def encode( logger.info("Request %s failed.", request_id) raise EngineGenerateError() from e - async def get_vllm_config(self) -> VllmConfig: - return self.vllm_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config + @property + def tokenizer(self) -> AnyTokenizer | None: + return self.processor.tokenizer - async def get_input_preprocessor(self) -> InputPreprocessor: - return self.processor.input_preprocessor + @tokenizer.setter + def tokenizer(self, tokenizer: AnyTokenizer | None) -> None: + self.processor.tokenizer = tokenizer async def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: @@ -657,10 +677,10 @@ async def stop_profile(self) -> None: await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.clear_cache() + 
self.processor.clear_mm_cache() await self.engine_core.reset_mm_cache_async() - async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: + async def reset_prefix_cache(self, device: Device | None = None) -> None: if device == Device.CPU: raise ValueError("Not supported on CPU.") await self.engine_core.reset_prefix_cache_async() @@ -669,7 +689,7 @@ async def sleep(self, level: int = 1) -> None: await self.reset_prefix_cache() await self.engine_core.sleep_async(level) - async def wake_up(self, tags: Optional[list[str]] = None) -> None: + async def wake_up(self, tags: list[str] | None = None) -> None: await self.engine_core.wake_up_async(tags) async def is_sleeping(self) -> bool: @@ -694,9 +714,9 @@ async def pin_lora(self, lora_id: int) -> bool: async def collective_rpc( self, method: str, - timeout: Optional[float] = None, + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict] = None, + kwargs: dict | None = None, ): """ Perform a collective RPC call to the given path. diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 9bb08e6db7be..39d8655ff858 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -4,14 +4,15 @@ import multiprocessing import time import weakref -from typing import Optional import msgspec.msgpack import zmq from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_mp_context, make_zmq_socket, set_process_title +from vllm.utils import get_mp_context +from vllm.utils.network_utils import make_zmq_socket +from vllm.utils.system_utils import set_process_title from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType from vllm.v1.serial_utils import MsgpackDecoder from vllm.v1.utils import get_engine_client_zmq_addr, shutdown @@ -155,7 +156,7 @@ def process_input_socket( stats_changed = False last_stats_step = -1 last_stats_wave = -1 - last_step_counts: Optional[list[list[int]]] = None + last_step_counts: list[list[int]] | None = None with ( make_zmq_socket( @@ -360,7 +361,7 @@ def process_input_socket( @staticmethod def _send_start_wave( - socket: zmq.Socket, wave: int, exclude_engine_index: Optional[int] + socket: zmq.Socket, wave: int, exclude_engine_index: int | None ): """Broadcast the START_DP_WAVE message to all the engines. 
It includes the current wave number and index of engine which diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 93f7fd5725bd..85cab32ebfb8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -7,18 +7,19 @@ import threading import time from collections import deque -from collections.abc import Generator +from collections.abc import Callable, Generator from concurrent.futures import Future from contextlib import ExitStack, contextmanager from inspect import isclass, signature from logging import DEBUG -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar import msgspec import zmq from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group +from vllm.envs import enable_envs_cache from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest @@ -26,14 +27,11 @@ from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.utils import ( - decorate_logs, - get_hash_fn_by_name, - make_zmq_socket, - resolve_obj_by_qualname, - set_process_title, -) from vllm.utils.gc_utils import maybe_attach_gc_debug_callback +from vllm.utils.hashing import get_hash_fn_by_name +from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.network_utils import make_zmq_socket +from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.v1.core.kv_cache_utils import ( BlockHash, generate_scheduler_kv_cache_config, @@ -58,7 +56,7 @@ EngineZmqAddresses, get_device_indices, ) -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput @@ -83,7 +81,7 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, - executor_fail_callback: Optional[Callable] = None, + executor_fail_callback: Callable | None = None, ): # plugins need to be loaded at the engine/scheduler level too from vllm.plugins import load_general_plugins @@ -91,11 +89,12 @@ def __init__( load_general_plugins() self.vllm_config = vllm_config - logger.info( - "Initializing a V1 LLM engine (v%s) with config: %s", - VLLM_VERSION, - vllm_config, - ) + if vllm_config.parallel_config.data_parallel_rank == 0: + logger.info( + "Initializing a V1 LLM engine (v%s) with config: %s", + VLLM_VERSION, + vllm_config, + ) self.log_stats = log_stats @@ -142,18 +141,22 @@ def __init__( logger.info("Disabling chunked prefill for model without KVCache") vllm_config.scheduler_config.chunked_prefill_enabled = False + scheduler_block_size = ( + vllm_config.cache_config.block_size + * vllm_config.parallel_config.decode_context_parallel_size + ) + self.scheduler: SchedulerInterface = Scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, structured_output_manager=self.structured_output_manager, include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, log_stats=self.log_stats, + block_size=scheduler_block_size, ) self.use_spec_decode = vllm_config.speculative_config is not None if self.scheduler.connector is not None: # type: ignore - self.model_executor.init_kv_output_aggregator( - self.scheduler.connector.get_finished_count() # type: ignore - ) + 
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore self.mm_registry = mm_registry = MULTIMODAL_REGISTRY self.mm_receiver_cache = engine_receiver_cache_from_config( @@ -165,26 +168,25 @@ def __init__( # schedule and execute batches, and is required by pipeline parallelism # to eliminate pipeline bubbles. self.batch_queue_size = self.model_executor.max_concurrent_batches - self.batch_queue: Optional[ - deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] - ] = None + self.batch_queue: ( + deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None + ) = None if self.batch_queue_size > 1: logger.info("Batch queue is enabled with size %d", self.batch_queue_size) self.batch_queue = deque(maxlen=self.batch_queue_size) - self.request_block_hasher: Optional[Callable[[Request], list[BlockHash]]] = None + self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None if ( self.vllm_config.cache_config.enable_prefix_caching or self.scheduler.get_kv_connector() is not None ): - block_size = vllm_config.cache_config.block_size caching_hash_fn = get_hash_fn_by_name( vllm_config.cache_config.prefix_caching_hash_algo ) init_none_hash(caching_hash_fn) self.request_block_hasher = get_request_block_hasher( - block_size, caching_hash_fn + scheduler_block_size, caching_hash_fn ) self.step_fn = ( @@ -232,9 +234,10 @@ def _initialize_kv_caches( self.model_executor.initialize_from_config(kv_cache_configs) elapsed = time.time() - start - logger.info( + logger.info_once( ("init engine (profile, create kv cache, warmup model) took %.2f seconds"), elapsed, + scope="local", ) return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config @@ -282,14 +285,11 @@ def abort_requests(self, request_ids: list[str]): # (i.e. client-aborted vs stop criteria met). self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) - def execute_model_with_error_logging( - self, - model_fn: Callable[[SchedulerOutput], ModelRunnerOutput], - scheduler_output: SchedulerOutput, - ) -> ModelRunnerOutput: + @contextmanager + def log_error_detail(self, scheduler_output: SchedulerOutput): """Execute the model and log detailed info on failure.""" try: - return model_fn(scheduler_output) + yield except Exception as err: # We do not want to catch BaseException here since we're only # interested in dumping info when the exception is due to an @@ -313,15 +313,15 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: if not self.scheduler.has_requests(): return {}, False scheduler_output = self.scheduler.schedule() - model_output = self.execute_model_with_error_logging( - self.model_executor.execute_model, # type: ignore - scheduler_output, - ) + + with self.log_error_detail(scheduler_output): + model_output = self.model_executor.execute_model(scheduler_output) + engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output - ) # type: ignore + ) - return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) + return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 def post_step(self, model_executed: bool) -> None: if self.use_spec_decode and model_executed: @@ -332,7 +332,7 @@ def post_step(self, model_executed: bool) -> None: def step_with_batch_queue( self, - ) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: + ) -> tuple[dict[int, EngineCoreOutputs] | None, bool]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. 
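On the `scheduler_block_size` introduced above: the scheduler-side block size is the cache block size scaled by the decode-context-parallel world size, since block hashes are computed over the full token sequence at that coarser granularity. A quick worked sketch of the arithmetic (the values below are illustrative, not vLLM defaults):

# Illustrative arithmetic only: scheduler-side block size under
# decode context parallelism (example values, not defaults).
def scheduler_block_size(cache_block_size: int, dcp_world_size: int) -> int:
    return cache_block_size * dcp_world_size

for dcp in (1, 2, 4):
    size = scheduler_block_size(cache_block_size=16, dcp_world_size=dcp)
    print(f"dcp_world_size={dcp}: hash granularity = {size} tokens")
# dcp_world_size=1 -> 16 tokens, 2 -> 32 tokens, 4 -> 64 tokens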
@@ -358,7 +358,7 @@ def step_with_batch_queue( if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() future = self.model_executor.execute_model(scheduler_output, non_block=True) - batch_queue.appendleft((future, scheduler_output)) # type: ignore[arg-type] + batch_queue.appendleft((future, scheduler_output)) model_executed = scheduler_output.total_num_scheduled_tokens > 0 if ( @@ -378,14 +378,12 @@ def step_with_batch_queue( # Block until the next result is available. future, scheduler_output = batch_queue.pop() - model_output = self.execute_model_with_error_logging( - lambda _: future.result(), scheduler_output - ) + with self.log_error_detail(scheduler_output): + model_output = future.result() engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output ) - return engine_core_outputs, model_executed def shutdown(self): @@ -400,23 +398,26 @@ def profile(self, is_start: bool = True): def reset_mm_cache(self): # NOTE: Since this is mainly for debugging, we don't attempt to - # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror) + # re-sync the internal caches (P0 sender, P1 receiver) if self.scheduler.has_unfinished_requests(): logger.warning( "Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches." ) + # The cache either exists in EngineCore or WorkerWrapperBase if self.mm_receiver_cache is not None: self.mm_receiver_cache.clear_cache() + self.model_executor.reset_mm_cache() + def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() def sleep(self, level: int = 1): self.model_executor.sleep(level) - def wake_up(self, tags: Optional[list[str]] = None): + def wake_up(self, tags: list[str] | None = None): self.model_executor.wake_up(tags) def is_sleeping(self) -> bool: @@ -440,8 +441,8 @@ def pin_lora(self, lora_id: int) -> bool: def save_sharded_state( self, path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, + pattern: str | None = None, + max_size: int | None = None, ) -> None: self.model_executor.save_sharded_state( path=path, pattern=pattern, max_size=max_size @@ -449,21 +450,13 @@ def save_sharded_state( def collective_rpc( self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, + method: str | Callable[..., _R], + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> list[_R]: return self.model_executor.collective_rpc(method, timeout, args, kwargs) - def save_tensorized_model( - self, - tensorizer_config, - ) -> None: - self.model_executor.save_tensorized_model( - tensorizer_config=tensorizer_config, - ) - def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. @@ -501,11 +494,11 @@ def __init__( handshake_address: str, executor_class: type[Executor], log_stats: bool, - client_handshake_address: Optional[str] = None, + client_handshake_address: str | None = None, engine_index: int = 0, ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() - self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], bytes]]() + self.output_queue = queue.Queue[tuple[int, EngineCoreOutputs] | bytes]() executor_fail_callback = lambda: self.input_queue.put_nowait( (EngineCoreRequestType.EXECUTOR_FAILED, b"") ) @@ -591,6 +584,10 @@ def __init__( # If enable, attach GC debugger after static variable freeze. 
maybe_attach_gc_debug_callback() + # Enable environment variable cache (e.g. assume no more + # environment variable overrides after this point) + enable_envs_cache() + @contextmanager def _perform_handshakes( self, @@ -598,7 +595,7 @@ def _perform_handshakes( identity: bytes, local_client: bool, vllm_config: VllmConfig, - client_handshake_address: Optional[str], + client_handshake_address: str | None, ) -> Generator[EngineZmqAddresses, None, None]: """ Perform startup handshakes. @@ -659,7 +656,7 @@ def _perform_handshake( local_client: bool, headless: bool, vllm_config: VllmConfig, - parallel_config_to_update: Optional[ParallelConfig] = None, + parallel_config_to_update: ParallelConfig | None = None, ) -> Generator[EngineZmqAddresses, None, None]: with make_zmq_socket( ctx, @@ -702,7 +699,7 @@ def startup_handshake( handshake_socket: zmq.Socket, local_client: bool, headless: bool, - parallel_config: Optional[ParallelConfig] = None, + parallel_config: ParallelConfig | None = None, ) -> EngineZmqAddresses: # Send registration message. handshake_socket.send( @@ -716,7 +713,7 @@ def startup_handshake( ) # Receive initialization message. - logger.info("Waiting for init message from front-end.") + logger.debug("Waiting for init message from front-end.") if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): raise RuntimeError( "Did not receive response from front-end " @@ -757,7 +754,7 @@ def signal_handler(signum, frame): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - engine_core: Optional[EngineCoreProc] = None + engine_core: EngineCoreProc | None = None try: parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config if parallel_config.data_parallel_size > 1 or dp_rank > 0: @@ -903,7 +900,7 @@ def _send_engine_dead(self): def process_input_sockets( self, input_addresses: list[str], - coord_input_address: Optional[str], + coord_input_address: str | None, identity: bytes, ready_event: threading.Event, ): @@ -972,7 +969,7 @@ def process_input_sockets( def process_output_sockets( self, output_paths: list[str], - coord_output_path: Optional[str], + coord_output_path: str | None, engine_index: int, ): """Output socket IO thread.""" @@ -1051,7 +1048,7 @@ def __init__( handshake_address: str, executor_class: type[Executor], log_stats: bool, - client_handshake_address: Optional[str] = None, + client_handshake_address: str | None = None, ): # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. @@ -1324,7 +1321,7 @@ def _perform_handshakes( identity: bytes, local_client: bool, vllm_config: VllmConfig, - client_handshake_address: Optional[str], + client_handshake_address: str | None, ): """ For Ray, we don't need to actually perform handshake. 
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 27283411eada..7b554ca991b9 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -9,11 +9,11 @@ import weakref from abc import ABC, abstractmethod from collections import defaultdict, deque -from collections.abc import Awaitable, Sequence +from collections.abc import Awaitable, Callable, Sequence from concurrent.futures import Future from dataclasses import dataclass from threading import Thread -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeAlias, TypeVar import msgspec.msgpack import zmq @@ -23,11 +23,11 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask -from vllm.utils import ( +from vllm.utils.async_utils import in_loop +from vllm.utils.network_utils import ( close_sockets, get_open_port, get_open_zmq_inproc_path, - in_loop, make_zmq_socket, ) from vllm.v1.engine import ( @@ -46,12 +46,12 @@ CoreEngineProcManager, launch_core_engines, ) -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr logger = init_logger(__name__) -AnyFuture = Union[asyncio.Future[Any], Future[Any]] +AnyFuture: TypeAlias = asyncio.Future[Any] | Future[Any] _R = TypeVar("_R") # Return type for collective_rpc @@ -99,7 +99,7 @@ def make_async_mp_client( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, + client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, ) -> "MPClient": @@ -144,7 +144,7 @@ def reset_prefix_cache(self) -> None: def sleep(self, level: int = 1) -> None: raise NotImplementedError - def wake_up(self, tags: Optional[list[str]] = None) -> None: + def wake_up(self, tags: list[str] | None = None) -> None: raise NotImplementedError def is_sleeping(self) -> bool: @@ -172,16 +172,16 @@ def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError def save_sharded_state( - self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + self, path: str, pattern: str | None = None, max_size: int | None = None ) -> None: raise NotImplementedError def collective_rpc( self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, + method: str | Callable[..., _R], + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> list[_R]: raise NotImplementedError @@ -214,7 +214,7 @@ async def reset_prefix_cache_async(self) -> None: async def sleep_async(self, level: int = 1) -> None: raise NotImplementedError - async def wake_up_async(self, tags: Optional[list[str]] = None) -> None: + async def wake_up_async(self, tags: list[str] | None = None) -> None: raise NotImplementedError async def is_sleeping_async(self) -> bool: @@ -236,16 +236,16 @@ async def pin_lora_async(self, lora_id: int) -> bool: raise NotImplementedError async def save_sharded_state_async( - self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + self, path: str, pattern: str | None = None, max_size: int | None = None ) -> None: raise NotImplementedError async def collective_rpc_async( self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, + method: str | Callable[..., _R], + timeout: float | None = None, args: tuple = (), - kwargs: 
Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> list[_R]: raise NotImplementedError @@ -293,7 +293,7 @@ def reset_prefix_cache(self) -> None: def sleep(self, level: int = 1) -> None: self.engine_core.sleep(level) - def wake_up(self, tags: Optional[list[str]] = None) -> None: + def wake_up(self, tags: list[str] | None = None) -> None: self.engine_core.wake_up(tags) def is_sleeping(self) -> bool: @@ -315,16 +315,16 @@ def pin_lora(self, lora_id: int) -> bool: return self.engine_core.pin_lora(lora_id) def save_sharded_state( - self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + self, path: str, pattern: str | None = None, max_size: int | None = None ) -> None: self.engine_core.save_sharded_state(path, pattern, max_size) def collective_rpc( self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, + method: str | Callable[..., _R], + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> list[_R]: return self.engine_core.collective_rpc(method, timeout, args, kwargs) @@ -340,18 +340,16 @@ class BackgroundResources: ctx: zmq.Context # If CoreEngineProcManager, it manages local engines; # if CoreEngineActorManager, it manages all engines. - engine_manager: Optional[Union[CoreEngineProcManager, CoreEngineActorManager]] = ( - None - ) - coordinator: Optional[DPCoordinator] = None - output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None - input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None - first_req_send_socket: Optional[zmq.asyncio.Socket] = None - first_req_rcv_socket: Optional[zmq.asyncio.Socket] = None - stats_update_socket: Optional[zmq.asyncio.Socket] = None - output_queue_task: Optional[asyncio.Task] = None - stats_update_task: Optional[asyncio.Task] = None - shutdown_path: Optional[str] = None + engine_manager: CoreEngineProcManager | CoreEngineActorManager | None = None + coordinator: DPCoordinator | None = None + output_socket: zmq.Socket | zmq.asyncio.Socket | None = None + input_socket: zmq.Socket | zmq.asyncio.Socket | None = None + first_req_send_socket: zmq.asyncio.Socket | None = None + first_req_rcv_socket: zmq.asyncio.Socket | None = None + stats_update_socket: zmq.asyncio.Socket | None = None + output_queue_task: asyncio.Task | None = None + stats_update_task: asyncio.Task | None = None + shutdown_path: str | None = None # Set if any of the engines are dead. Here so that the output # processing threads can access it without holding a ref to the client. @@ -438,7 +436,7 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, + client_addresses: dict[str, str] | None = None, ): self.vllm_config = vllm_config # Serialization setup. @@ -459,7 +457,7 @@ def __init__( # State used for data parallel. self.engines_running = False - self.stats_update_address: Optional[str] = None + self.stats_update_address: str | None = None if client_addresses: # Engines are managed externally to this client. input_address = client_addresses["input_address"] @@ -646,7 +644,7 @@ def __init__( ) self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1 - self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]() + self.outputs_queue = queue.Queue[EngineCoreOutputs | Exception]() # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. 
@@ -770,7 +768,7 @@ def pin_lora(self, lora_id: int) -> bool: def sleep(self, level: int = 1) -> None: self.call_utility("sleep", level) - def wake_up(self, tags: Optional[list[str]] = None) -> None: + def wake_up(self, tags: list[str] | None = None) -> None: self.call_utility("wake_up", tags) def is_sleeping(self) -> bool: @@ -781,15 +779,15 @@ def execute_dummy_batch(self) -> None: def collective_rpc( self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, + method: str | Callable[..., _R], + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> list[_R]: return self.call_utility("collective_rpc", method, timeout, args, kwargs) def save_sharded_state( - self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + self, path: str, pattern: str | None = None, max_size: int | None = None ) -> None: self.call_utility("save_sharded_state", path, pattern, max_size) @@ -802,7 +800,7 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, + client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, ): @@ -816,7 +814,7 @@ def __init__( self.client_count = client_count self.client_index = client_index - self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() + self.outputs_queue = asyncio.Queue[EngineCoreOutputs | Exception]() try: # If we are running in an asyncio event loop, start the queue task. # Otherwise, it will be started lazily. If it is not started here, @@ -837,9 +835,9 @@ def _ensure_output_queue_task(self): decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue - output_handler: Optional[ - Callable[[AsyncMPClient, EngineCoreOutputs], Awaitable[None]] - ] = getattr(self.__class__, "process_engine_outputs", None) + output_handler: ( + Callable[[AsyncMPClient, EngineCoreOutputs], Awaitable[None]] | None + ) = getattr(self.__class__, "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None output_socket = resources.output_socket assert output_socket is not None @@ -888,7 +886,7 @@ def _send_input( self, request_type: EngineCoreRequestType, request: Any, - engine: Optional[EngineIdentity] = None, + engine: EngineIdentity | None = None, ) -> Awaitable[Any]: if engine is None: engine = self.core_engine @@ -962,7 +960,7 @@ async def reset_prefix_cache_async(self) -> None: async def sleep_async(self, level: int = 1) -> None: await self.call_utility_async("sleep", level) - async def wake_up_async(self, tags: Optional[list[str]] = None) -> None: + async def wake_up_async(self, tags: list[str] | None = None) -> None: await self.call_utility_async("wake_up", tags) async def is_sleeping_async(self) -> bool: @@ -984,16 +982,16 @@ async def pin_lora_async(self, lora_id: int) -> bool: return await self.call_utility_async("pin_lora", lora_id) async def save_sharded_state_async( - self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + self, path: str, pattern: str | None = None, max_size: int | None = None ) -> None: await self.call_utility_async("save_sharded_state", path, pattern, max_size) async def collective_rpc_async( self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, + method: str | Callable[..., _R], + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, + kwargs: 
dict[str, Any] | None = None, ) -> list[_R]: return await self.call_utility_async( "collective_rpc", method, timeout, args, kwargs @@ -1009,7 +1007,7 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, + client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, ): @@ -1166,7 +1164,7 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, + client_addresses: dict[str, str] | None = None, client_count: int = 1, client_index: int = 0, ): diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 9d1d7558b1ed..5f66e36893bf 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Optional import tokenizers from packaging import version @@ -36,7 +35,7 @@ def __init__(self): def output_token_ids(self) -> list[int]: return self.token_ids - def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: + def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None: self.token_ids.extend(new_token_ids) return None @@ -46,7 +45,7 @@ def get_next_output_text(self, finished: bool, delta: bool) -> str: @classmethod def from_new_request( cls, - tokenizer: Optional[AnyTokenizer], + tokenizer: AnyTokenizer | None, request: EngineCoreRequest, ) -> "IncrementalDetokenizer": assert request.sampling_params is not None @@ -85,7 +84,7 @@ def __init__(self, request: EngineCoreRequest): # Generation data self.output_text = "" - def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: + def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. @@ -224,13 +223,13 @@ def decode_next(self, next_token_id: int) -> str: return token or "" - def _protected_step(self, next_token_id: int) -> Optional[str]: + def _protected_step(self, next_token_id: int) -> str | None: try: token = self.stream.step(self.tokenizer, next_token_id) - except OverflowError: + except (OverflowError, TypeError): # Handle rare observed overflow, still to be diagnosed. # See https://github.com/vllm-project/vllm/issues/21951. - logger.exception("Encountered invalid token id: %d", next_token_id) + logger.exception("Encountered invalid token id: %r", next_token_id) token = None except Exception as e: if not str(e).startswith(INVALID_PREFIX_ERR_MSG): @@ -312,7 +311,7 @@ def check_stop_strings( new_char_count: int, stop: list[str], include_in_output: bool, -) -> Optional[tuple[str, int]]: +) -> tuple[str, int] | None: """Check if any stop strings are matched and truncate sequence output text accordingly. 
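The detokenizer hunk above widens the protective wrapper around incremental decoding so that a malformed token id degrades to an empty token instead of failing the whole request, and logs the offending value with %r since it may not be an int. A minimal sketch of that pattern; stream and tokenizer are placeholders for the real tokenizers objects, and the second except branch of the original is omitted.

    import logging

    logger = logging.getLogger(__name__)

    def protected_step(stream, tokenizer, next_token_id) -> str | None:
        try:
            token = stream.step(tokenizer, next_token_id)
        except (OverflowError, TypeError):
            # Rare, still-undiagnosed overflow/invalid-id cases; see the
            # issue linked in the hunk above.
            logger.exception("Encountered invalid token id: %r", next_token_id)
            token = None
        return token
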
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 701a62580562..9d69ed93ed37 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from collections.abc import Mapping +from collections.abc import Callable, Mapping from copy import copy -from typing import Any, Callable, Optional, Union +from typing import Any import torch.nn as nn from typing_extensions import TypeVar @@ -19,11 +19,12 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.tracing import init_tracer -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext from vllm.utils import Device from vllm.v1.engine import EngineCoreRequest @@ -31,7 +32,7 @@ from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.reader import Metric, get_metrics_snapshot from vllm.v1.metrics.stats import IterationStats @@ -50,8 +51,9 @@ def __init__( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + aggregate_engine_logging: bool = False, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_loggers: list[StatLoggerFactory] | None = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, multiprocess_mode: bool = False, @@ -95,8 +97,16 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(vllm_config, mm_registry=mm_registry) + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = init_tokenizer_from_configs(self.model_config) + + self.processor = Processor(self.vllm_config, tokenizer) + self.io_processor = get_io_processor( + self.vllm_config, + self.model_config.io_processor_plugin, + ) # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). 
self.output_processor = OutputProcessor( @@ -117,12 +127,13 @@ def __init__( log_stats=self.log_stats, ) - self.logger_manager: Optional[StatLoggerManager] = None + self.logger_manager: StatLoggerManager | None = None if self.log_stats: self.logger_manager = StatLoggerManager( vllm_config=vllm_config, custom_stat_loggers=stat_loggers, enable_default_loggers=log_stats, + aggregate_engine_logging=aggregate_engine_logging, ) self.logger_manager.log_engine_initialized() @@ -143,7 +154,7 @@ def from_vllm_config( cls, vllm_config: VllmConfig, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_loggers: list[StatLoggerFactory] | None = None, disable_log_stats: bool = False, ) -> "LLMEngine": return cls( @@ -160,7 +171,7 @@ def from_engine_args( cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, + stat_loggers: list[StatLoggerFactory] | None = None, enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" @@ -204,14 +215,6 @@ def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: def validate_outputs(cls, outputs, output_type): return outputs - @property - def tokenizer(self) -> Optional[AnyTokenizer]: - return self.processor.tokenizer - - @tokenizer.setter - def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None: - self.processor.tokenizer = tokenizer - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.engine_core.get_supported_tasks() @@ -224,14 +227,14 @@ def abort_request(self, request_ids: list[str]) -> None: def add_request( self, request_id: str, - prompt: Union[EngineCoreRequest, PromptType], - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - trace_headers: Optional[Mapping[str, str]] = None, + prompt: EngineCoreRequest | PromptType, + params: SamplingParams | PoolingParams, + arrival_time: float | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - prompt_text: Optional[str] = None, + prompt_text: str | None = None, ) -> None: # Validate the request_id type. if not isinstance(request_id, str): @@ -282,7 +285,7 @@ def add_request( # Add the request to EngineCore. 
self.engine_core.add_request(child_request) - def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: + def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]: if self.should_execute_dummy_batch: self.should_execute_dummy_batch = False self.engine_core.execute_dummy_batch() @@ -305,20 +308,16 @@ def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: # 4) Record stats if self.logger_manager is not None: assert outputs.scheduler_stats is not None + self.logger_manager.record( scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, + mm_cache_stats=self.processor.stat_mm_cache(), ) self.do_log_stats_with_interval() return processed_outputs.request_outputs - def get_vllm_config(self): - return self.vllm_config - - def get_model_config(self): - return self.model_config - def start_profile(self): self.engine_core.profile(True) @@ -326,16 +325,16 @@ def stop_profile(self): self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.clear_cache() + self.processor.clear_mm_cache() self.engine_core.reset_mm_cache() - def reset_prefix_cache(self, device: Optional[Device] = None): + def reset_prefix_cache(self, device: Device | None = None): self.engine_core.reset_prefix_cache() def sleep(self, level: int = 1): self.engine_core.sleep(level) - def wake_up(self, tags: Optional[list[str]] = None): + def wake_up(self, tags: list[str] | None = None): self.engine_core.wake_up(tags) def is_sleeping(self) -> bool: @@ -345,6 +344,14 @@ def get_metrics(self) -> list[Metric]: assert self.log_stats, "Stat logging disabled" return get_metrics_snapshot() + @property + def tokenizer(self) -> AnyTokenizer | None: + return self.processor.tokenizer + + @tokenizer.setter + def tokenizer(self, tokenizer: AnyTokenizer | None) -> None: + self.processor.tokenizer = tokenizer + def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: raise ValueError( @@ -385,10 +392,10 @@ def pin_lora(self, lora_id: int) -> bool: def collective_rpc( self, - method: Union[str, Callable[[WorkerBase], _R]], - timeout: Optional[float] = None, + method: str | Callable[[WorkerBase], _R], + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> list[_R]: return self.engine_core.collective_rpc(method, timeout, args, kwargs) diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index ab0e44fce155..48bb5312f5d9 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -4,7 +4,6 @@ import itertools from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs @@ -24,19 +23,19 @@ class LogprobsProcessor: # Tokenizer for this request, # None if detokenization is disabled. 
- tokenizer: Optional[AnyTokenizer] + tokenizer: AnyTokenizer | None # Logprobs for this request - logprobs: Optional[SampleLogprobs] - prompt_logprobs: Optional[PromptLogprobs] - cumulative_logprob: Optional[float] - num_logprobs: Optional[int] - num_prompt_logprobs: Optional[int] + logprobs: SampleLogprobs | None + prompt_logprobs: PromptLogprobs | None + cumulative_logprob: float | None + num_logprobs: int | None + num_prompt_logprobs: int | None @classmethod def from_new_request( cls, - tokenizer: Optional[AnyTokenizer], + tokenizer: AnyTokenizer | None, request: EngineCoreRequest, ) -> "LogprobsProcessor": assert request.sampling_params is not None @@ -67,7 +66,7 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: assert self.logprobs is not None assert self.cumulative_logprob is not None - token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists + token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): # Detokenize (non-incrementally). @@ -148,7 +147,7 @@ def _update_prompt_logprobs( ) ) - def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: + def pop_prompt_logprobs(self) -> PromptLogprobs | None: """Pop and return all request prompt logprobs The logprobs processor aggregates prompt chunk logprobs @@ -171,7 +170,7 @@ def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: def _make_logprob_dict( logprobs: list[float], logprob_token_ids: list[int], - decoded_tokens: Iterable[Optional[str]], + decoded_tokens: Iterable[str | None], rank: int, num_logprobs: int, ) -> dict[int, Logprob]: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index eb65b68969e3..44e4eadce42a 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -4,7 +4,7 @@ import asyncio from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Optional, Union, cast +from typing import Any, cast import torch @@ -36,14 +36,10 @@ class RequestOutputCollector: def __init__(self, output_kind: RequestOutputKind): self.aggregate = output_kind == RequestOutputKind.DELTA - self.output: Optional[Union[RequestOutput, PoolingRequestOutput, Exception]] = ( - None - ) + self.output: RequestOutput | PoolingRequestOutput | Exception | None = None self.ready = asyncio.Event() - def put( - self, output: Union[RequestOutput, PoolingRequestOutput, Exception] - ) -> None: + def put(self, output: RequestOutput | PoolingRequestOutput | Exception) -> None: """Non-blocking put operation.""" if self.output is None or isinstance(output, Exception): self.output = output @@ -53,7 +49,7 @@ def put( # (if n > 1) do not override each other. 
self.output.add(output, aggregate=self.aggregate) - async def get(self) -> Union[RequestOutput, PoolingRequestOutput]: + async def get(self) -> RequestOutput | PoolingRequestOutput: """Get operation blocks on put event.""" while (output := self.output) is None: await self.ready.wait() @@ -63,7 +59,7 @@ async def get(self) -> Union[RequestOutput, PoolingRequestOutput]: raise output return output - def get_nowait(self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: + def get_nowait(self) -> RequestOutput | PoolingRequestOutput | None: """Non-blocking get operation.""" output = self.output if output is not None: @@ -76,7 +72,7 @@ def get_nowait(self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: @dataclass class OutputProcessorOutput: - request_outputs: list[Union[RequestOutput, PoolingRequestOutput]] + request_outputs: list[RequestOutput | PoolingRequestOutput] reqs_to_abort: list[str] @@ -84,22 +80,22 @@ class RequestState: def __init__( self, request_id: str, - parent_req: Optional[ParentRequest], + parent_req: ParentRequest | None, request_index: int, - lora_name: Optional[str], + lora_name: str | None, output_kind: RequestOutputKind, - prompt: Optional[str], - prompt_token_ids: Optional[list[int]], - prompt_embeds: Optional[torch.Tensor], - logprobs_processor: Optional[LogprobsProcessor], - detokenizer: Optional[IncrementalDetokenizer], - max_tokens_param: Optional[int], + prompt: str | None, + prompt_token_ids: list[int] | None, + prompt_embeds: torch.Tensor | None, + logprobs_processor: LogprobsProcessor | None, + detokenizer: IncrementalDetokenizer | None, + max_tokens_param: int | None, arrival_time: float, - queue: Optional[RequestOutputCollector], + queue: RequestOutputCollector | None, log_stats: bool, - top_p: Optional[float] = None, - n: Optional[int] = None, - temperature: Optional[float] = None, + top_p: float | None = None, + n: int | None = None, + temperature: float | None = None, ): self.request_id = request_id self.parent_req = parent_req @@ -129,10 +125,10 @@ def from_new_request( cls, tokenizer: AnyTokenizer, request: EngineCoreRequest, - prompt: Optional[str], - parent_req: Optional[ParentRequest], + prompt: str | None, + parent_req: ParentRequest | None, request_index: int, - queue: Optional[RequestOutputCollector], + queue: RequestOutputCollector | None, log_stats: bool, ) -> "RequestState": if sampling_params := request.sampling_params: @@ -186,11 +182,11 @@ def from_new_request( def make_request_output( self, new_token_ids: list[int], - pooling_output: Optional[torch.Tensor], - finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], - kv_transfer_params: Optional[dict[str, Any]] = None, - ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: + pooling_output: torch.Tensor | None, + finish_reason: FinishReason | None, + stop_reason: int | str | None, + kv_transfer_params: dict[str, Any] | None = None, + ) -> RequestOutput | PoolingRequestOutput | None: finished = finish_reason is not None final_only = self.output_kind == RequestOutputKind.FINAL_ONLY @@ -222,10 +218,10 @@ def make_request_output( def _new_request_output( self, request_id: str, - outputs: Union[list[CompletionOutput], list[PoolingOutput]], + outputs: list[CompletionOutput] | list[PoolingOutput], finished: bool, - kv_transfer_params: Optional[dict[str, Any]] = None, - ) -> Union[RequestOutput, PoolingRequestOutput]: + kv_transfer_params: dict[str, Any] | None = None, + ) -> RequestOutput | PoolingRequestOutput: first_output = outputs[0] if 
isinstance(first_output, PoolingOutput): assert len(outputs) == 1 @@ -234,6 +230,7 @@ def _new_request_output( return PoolingRequestOutput( request_id=request_id, outputs=first_output, + num_cached_tokens=self.num_cached_tokens, prompt_token_ids=self.prompt_token_ids, finished=finished, ) @@ -264,8 +261,8 @@ def _new_request_output( def _new_completion_output( self, token_ids: list[int], - finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], + finish_reason: FinishReason | None, + stop_reason: int | str | None, ) -> CompletionOutput: assert self.detokenizer is not None assert self.logprobs_processor is not None @@ -308,7 +305,7 @@ def __init__(self, tokenizer: AnyTokenizer, log_stats: bool): self.request_states: dict[str, RequestState] = {} self.parent_requests: dict[str, ParentRequest] = {} self.lora_states = LoRARequestStates() - self.tracer: Optional[Tracer] = None + self.tracer: Tracer | None = None def get_num_unfinished_requests(self): return len(self.request_states) @@ -360,10 +357,10 @@ def abort_requests( def add_request( self, request: EngineCoreRequest, - prompt: Optional[str], - parent_req: Optional[ParentRequest] = None, + prompt: str | None, + parent_req: ParentRequest | None = None, request_index: int = 0, - queue: Optional[RequestOutputCollector] = None, + queue: RequestOutputCollector | None = None, ) -> None: request_id = request.request_id if request_id in self.request_states: @@ -386,8 +383,8 @@ def add_request( def process_outputs( self, engine_core_outputs: list[EngineCoreOutput], - engine_core_timestamp: Optional[float] = None, - iteration_stats: Optional[IterationStats] = None, + engine_core_timestamp: float | None = None, + iteration_stats: IterationStats | None = None, ) -> OutputProcessorOutput: """ Process the EngineCoreOutputs: @@ -411,7 +408,7 @@ def process_outputs( within the loop below. 
""" - request_outputs: Union[list[RequestOutput], list[PoolingRequestOutput]] = [] + request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = [] reqs_to_abort: list[str] = [] for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id @@ -492,7 +489,7 @@ def do_tracing( self, engine_core_output: EngineCoreOutput, req_state: RequestState, - iteration_stats: Optional[IterationStats], + iteration_stats: IterationStats | None, ) -> None: assert req_state.stats is not None assert iteration_stats is not None @@ -555,8 +552,8 @@ def _update_stats_from_output( self, req_state: RequestState, engine_core_output: EngineCoreOutput, - engine_core_timestamp: Optional[float], - iteration_stats: Optional[IterationStats], + engine_core_timestamp: float | None, + iteration_stats: IterationStats | None, ): if iteration_stats is None: return @@ -577,8 +574,8 @@ def _update_stats_from_output( def _update_stats_from_finished( self, req_state: RequestState, - finish_reason: Optional[FinishReason], - iteration_stats: Optional[IterationStats], + finish_reason: FinishReason | None, + iteration_stats: IterationStats | None, ): if iteration_stats is None: return diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index daf115c0325f..2a47befec25f 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -29,7 +29,7 @@ class ParentRequest: max_num_generation_tokens: int # To efficiently obtain child sampling params - cached_child_sampling_params: Optional[SamplingParams] + cached_child_sampling_params: SamplingParams | None def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index f39e9c1eea7d..de15677aeea9 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs @@ -21,6 +21,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest +from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar from vllm.v1.structured_output.backend_lm_format_enforcer import ( validate_structured_output_request_lm_format_enforcer, @@ -37,6 +38,7 @@ class Processor: def __init__( self, vllm_config: VllmConfig, + tokenizer: AnyTokenizer | None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: self.vllm_config = vllm_config @@ -52,16 +54,17 @@ def __init__( self.input_preprocessor = InputPreprocessor( self.model_config, + tokenizer, mm_registry, mm_processor_cache=self.mm_processor_cache, ) @property - def tokenizer(self) -> Optional[AnyTokenizer]: + def tokenizer(self) -> AnyTokenizer | None: return self.input_preprocessor.tokenizer @tokenizer.setter - def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None: + def tokenizer(self, tokenizer: AnyTokenizer | None) -> None: self.input_preprocessor.tokenizer = tokenizer def _validate_logprobs( @@ -149,7 +152,7 @@ def _validate_supported_sampling_params( def _validate_params( self, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams | PoolingParams, ): """ Validate 
supported SamplingParam. @@ -171,7 +174,7 @@ def _validate_multi_modal_uuids(self, prompt: PromptType) -> None: auto-hashed downstream. """ - def _validate_single_prompt(single_prompt: Union[dict, str]) -> None: + def _validate_single_prompt(single_prompt: dict | str) -> None: if not isinstance(single_prompt, dict): return mm_data = single_prompt.get("multi_modal_data") @@ -211,7 +214,7 @@ def _validate_single_prompt(single_prompt: Union[dict, str]) -> None: else: _validate_single_prompt(prompt) # type: ignore[arg-type] - def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: + def _validate_lora(self, lora_request: LoRARequest | None) -> None: if lora_request is None: return @@ -306,7 +309,7 @@ def _maybe_build_mm_uuids( self, request_id: str, prompt: PromptType, - ) -> Optional[MultiModalUUIDDict]: + ) -> MultiModalUUIDDict | None: """Build per-item multimodal hash overrides when enabled. In this case, multimodal data items are identified by their request id, modality and index rather than their content. @@ -339,13 +342,13 @@ def process_inputs( self, request_id: str, prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, - trace_headers: Optional[Mapping[str, str]] = None, + params: SamplingParams | PoolingParams, + arrival_time: float | None = None, + lora_request: LoRARequest | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + trace_headers: Mapping[str, str] | None = None, priority: int = 0, - data_parallel_rank: Optional[int] = None, + data_parallel_rank: int | None = None, ) -> EngineCoreRequest: self._validate_lora(lora_request) self._validate_params(params) @@ -442,7 +445,7 @@ def process_inputs( pooling_params = params.clone() # Multimodal related. 
- mm_features: Optional[list[MultiModalFeatureSpec]] = None + mm_features: list[MultiModalFeatureSpec] | None = None if decoder_inputs["type"] == "multimodal": decoder_mm_inputs = decoder_inputs["mm_kwargs"] @@ -482,7 +485,7 @@ def process_inputs( ) def _validate_model_inputs( - self, encoder_inputs: Optional[SingletonInputs], decoder_inputs: SingletonInputs + self, encoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs ): if encoder_inputs is not None: self._validate_model_input(encoder_inputs, prompt_type="encoder") @@ -571,5 +574,8 @@ def _validate_model_input( # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens - def clear_cache(self) -> None: - self.input_preprocessor.clear_cache() + def stat_mm_cache(self) -> MultiModalCacheStats | None: + return self.input_preprocessor.stat_mm_cache() + + def clear_mm_cache(self) -> None: + self.input_preprocessor.clear_mm_cache() diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index f3bc8fa19bef..ca416dbc0df9 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -4,24 +4,26 @@ import contextlib import os import weakref -from collections.abc import Iterator +from collections.abc import Callable, Iterator from dataclasses import dataclass from enum import Enum, auto from multiprocessing import Process, connection from multiprocessing.process import BaseProcess -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING from unittest.mock import patch import msgspec import zmq +from vllm import envs from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy -from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx +from vllm.utils import get_mp_context +from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx from vllm.v1.engine.coordinator import DPCoordinator -from vllm.v1.executor.abstract import Executor +from vllm.v1.executor import Executor from vllm.v1.utils import get_engine_client_zmq_addr, shutdown if TYPE_CHECKING: @@ -55,13 +57,13 @@ class EngineZmqAddresses: # ZMQ output socket addresses for each front-end client (responses) outputs: list[str] # ZMQ input socket address of DP coordinator if applicable - coordinator_input: Optional[str] = None + coordinator_input: str | None = None # ZMQ output socket address of DP coordinator if applicable - coordinator_output: Optional[str] = None + coordinator_output: str | None = None # ZMQ socket for front-end to connect to DP coordinator. # Not used by engine, just relayed to front-end in handshake response. # Only required for external DP LB case. 
- frontend_stats_publish_address: Optional[str] = None + frontend_stats_publish_address: str | None = None @dataclass @@ -72,8 +74,8 @@ class EngineHandshakeMetadata: """ addresses: EngineZmqAddresses - parallel_config: dict[str, Union[int, str, list[int]]] - parallel_config_hash: Optional[str] = None + parallel_config: dict[str, int | str | list[int]] + parallel_config_hash: str | None = None class CoreEngineProcManager: @@ -93,7 +95,7 @@ def __init__( handshake_address: str, executor_class: type[Executor], log_stats: bool, - client_handshake_address: Optional[str] = None, + client_handshake_address: str | None = None, ): context = get_mp_context() common_kwargs = { @@ -220,8 +222,8 @@ def __init__( addresses: EngineZmqAddresses, executor_class: type[Executor], log_stats: bool, - placement_groups: Optional[list["PlacementGroup"]] = None, - local_dp_ranks: Optional[list[int]] = None, + placement_groups: list["PlacementGroup"] | None = None, + local_dp_ranks: list[int] | None = None, ): import copy @@ -337,13 +339,14 @@ def create_dp_placement_groups( logger.info("Creating placement groups for data parallel") dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip - num_pg_to_create = vllm_config.parallel_config.data_parallel_size - local_engine_count = vllm_config.parallel_config.data_parallel_size_local + dp_size = vllm_config.parallel_config.data_parallel_size + dp_size_local = vllm_config.parallel_config.data_parallel_size_local available_resources = available_resources_per_node() world_size = vllm_config.parallel_config.world_size placement_groups: list[PlacementGroup] = [] local_dp_ranks: list[int] = [] + dp_master_ip_key = f"node:{dp_master_ip}" nodes = sorted( available_resources.values(), key=lambda x: dp_master_ip_key not in x @@ -354,49 +357,161 @@ def create_dp_placement_groups( dp_master_ip, ) device_str = current_platform.ray_device_key + n_node_devices: list[int] = [ + int(node_resources[device_str]) + for node_resources in nodes + if device_str in node_resources + ] + assert n_node_devices, f"No {device_str} found in Ray cluster." + max_device_per_node = max(n_node_devices) + + pack_strategy = envs.VLLM_RAY_DP_PACK_STRATEGY + _supported_pack_strategies = ("strict", "fill", "span") + if pack_strategy not in _supported_pack_strategies: + raise ValueError( + f"{envs.VLLM_RAY_DP_PACK_STRATEGY} is not supported. " + "Make sure to set `VLLM_RAY_DP_PACK_STRATEGY` " + f"to one of {_supported_pack_strategies}" + ) + + all2all_backend = vllm_config.parallel_config.all2all_backend + if pack_strategy == "fill" and ( + all2all_backend == "deepep_high_throughput" + or all2all_backend == "deepep_low_latency" + ): + raise ValueError( + "DeepEP kernels require EP ranks [0,7] (same for [8,15], ...) " + "to be on the same node, but VLLM_RAY_DP_PACK_STRATEGY=fill " + "does not guarantee that. " + "Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead." + ) + + if pack_strategy in ("strict", "fill"): + placement_strategy = "STRICT_PACK" + else: + placement_strategy = "PACK" + assert world_size > max_device_per_node, ( + f"World size {world_size} is smaller than the " + "maximum number of devices per node " + f"{max_device_per_node}. 
Make sure to set " + "`VLLM_RAY_DP_PACK_STRATEGY` to `strict` or `fill`" + ) + + # if we need multiple nodes per dp group, we require for now that + # available nodes are homogenous + assert set(n_node_devices) == {max_device_per_node}, ( + f"Nodes are not homogenous, {nodes}" + ) + assert world_size % max_device_per_node == 0, ( + f"For multi-node data parallel groups, world_size ({world_size}) must " + f"be a multiple of number of devices per node ({max_device_per_node})." + ) + assert len(n_node_devices) * max_device_per_node >= world_size * dp_size, ( + f"Not enough total available nodes ({len(n_node_devices)}) " + f"and devices per node ({max_device_per_node}) " + f"to satisfy required world size {world_size} and data parallel size " + f"{dp_size}" + ) + assert dp_size_local == 1, ( + f"data-parallel-size-local {dp_size_local} should be set as the " + "default (1) for VLLM_RAY_DP_PACK_STRATEGY=span. " + "The actual data-parallel-size-local will be auto determined." + ) + + # bundles collected for a single DP rank from multiple nodes, + # for "span" pack strategy + collected_bundles = [] for node_resources in nodes: - if device_str not in node_resources: - continue - # For now, each DP rank can only be assigned to one node - # TODO(rui): support allocating a single DP rank - # to multiple nodes - available_engine_count = int(node_resources[device_str]) // world_size - if dp_master_ip_key in node_resources: - assert available_engine_count >= local_engine_count, ( - "Not enough resources to allocate DP ranks " - f"on DP master node {dp_master_ip}" - ) - for i in range(local_engine_count): - bundles = [ - {device_str: 1.0, "node:" + dp_master_ip: 0.001} - ] * world_size + [{"CPU": 1.0}] - pg = ray.util.placement_group( - name=f"dp_rank_{len(placement_groups)}", - strategy="STRICT_PACK", - bundles=bundles, + node_ip_keys = [ + key + for key in node_resources + if key != "node:__internal_head__" and key.startswith("node:") + ] + assert len(node_ip_keys) == 1, ( + "Zero or multiple node IP keys found in node resources: %s", + node_ip_keys, + ) + node_ip_key = node_ip_keys[0] + node_ip = node_ip_key.split(":")[1] + + n_device_on_node = int(node_resources.get(device_str, 0)) + if pack_strategy == "span" and n_device_on_node != 0: + # Strictly speaking, + # dp_size_available = n_device_on_node / world_size + # and is a fraction, but we use 1 for easier processing + dp_size_available = 1 + else: + dp_size_available = n_device_on_node // world_size + + if node_ip == dp_master_ip: + if dp_size_available < dp_size_local: + raise ValueError( + "Not enough resources to allocate %s DP ranks " + "on DP master node %s, possible to fit %s DP ranks", + dp_size_local, + dp_master_ip, + dp_size_available, ) - placement_groups.append(pg) - local_dp_ranks.append(i) + dp_size_to_allocate = dp_size_local + elif pack_strategy == "strict": + if dp_size_available < dp_size_local: + logger.info( + "Skipping node %s as %s DP ranks could not fit, " + "possible to fit %s DP ranks", + node_ip, + dp_size_local, + dp_size_available, + ) + continue + dp_size_to_allocate = dp_size_local else: - for i in range(available_engine_count): - if len(placement_groups) == num_pg_to_create: - break - bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}] - pg = ray.util.placement_group( - name=f"dp_rank_{len(placement_groups)}", - strategy="STRICT_PACK", - bundles=bundles, + # for "pack_strategy" in "fill" and "span" + # we always take everything that's available + dp_size_to_allocate = dp_size_available + + for i in 
range(dp_size_to_allocate): + device_bundle = [{device_str: 1.0, "node:" + node_ip: 0.001}] + if pack_strategy == "span": + collected_bundles += device_bundle * n_device_on_node + assert len(collected_bundles) <= world_size, ( + "collected_bundles should be <= world_size, " + f"but got {len(collected_bundles)=} and {world_size=}" ) - placement_groups.append(pg) - local_dp_ranks.append(i) - if len(placement_groups) < num_pg_to_create: + + # we only create a placement group if we collected enough devices + if len(collected_bundles) < world_size: + continue + + bundles = collected_bundles + [{"CPU": 1.0}] + collected_bundles = [] + else: + bundles = device_bundle * world_size + [{"CPU": 1.0}] + + pg = ray.util.placement_group( + name=f"dp_rank_{len(placement_groups)}", + strategy=placement_strategy, + bundles=bundles, + ) + placement_groups.append(pg) + local_dp_ranks.append(i) + if len(placement_groups) == dp_size: + break + + if len(placement_groups) < dp_size: raise ValueError( - f"Not enough resources to allocate {num_pg_to_create} " + f"Not enough resources to allocate {dp_size} " "placement groups, only created " f"{len(placement_groups)} placement groups. " "Available resources: " f"{available_resources}" ) + assert len(placement_groups) == dp_size, ( + f"Created {len(placement_groups)} DP placement groups, expected {dp_size}" + ) + assert len(local_dp_ranks) == dp_size, ( + f"local_dp_ranks length {len(local_dp_ranks)} does not match " + f"expected {dp_size}" + ) return placement_groups, local_dp_ranks @staticmethod @@ -634,8 +749,8 @@ def launch_core_engines( num_api_servers: int = 1, ) -> Iterator[ tuple[ - Optional[Union[CoreEngineProcManager, CoreEngineActorManager]], - Optional[DPCoordinator], + CoreEngineProcManager | CoreEngineActorManager | None, + DPCoordinator | None, EngineZmqAddresses, ] ]: @@ -788,8 +903,8 @@ def wait_for_engine_startup( core_engines: list[CoreEngine], parallel_config: ParallelConfig, cache_config: CacheConfig, - proc_manager: Optional[CoreEngineProcManager], - coord_process: Optional[Process], + proc_manager: CoreEngineProcManager | None, + coord_process: Process | None, ): # Wait for engine core process(es) to send ready messages. 
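The placement-group rewrite above introduces three VLLM_RAY_DP_PACK_STRATEGY modes: "strict" and "fill" keep every engine on a single node (STRICT_PACK), while "span" lets one engine's world_size stretch across whole, homogeneous nodes (PACK). Below is a much-simplified sketch of how the three modes turn per-node device counts into group sizes; it ignores the DP master-node special case and the real Ray bundle dicts, and plan_groups is a hypothetical helper rather than code from this patch.

    def plan_groups(node_devices: list[int], world_size: int, dp_size: int,
                    dp_size_local: int, strategy: str) -> list[int]:
        groups: list[int] = []
        if strategy in ("strict", "fill"):
            for devices in node_devices:
                fits = devices // world_size
                if strategy == "strict":
                    if fits < dp_size_local:
                        continue  # strict: skip nodes that cannot host a full local set
                    fits = dp_size_local
                for _ in range(fits):
                    if len(groups) == dp_size:
                        return groups
                    groups.append(world_size)  # one single-node engine
        else:  # "span": accumulate whole nodes until an engine is covered
            collected = 0
            for devices in node_devices:
                collected += devices
                if collected >= world_size:
                    groups.append(collected)
                    collected = 0
                if len(groups) == dp_size:
                    break
        return groups

    # With 8-GPU nodes and world_size=16, "strict"/"fill" cannot place any
    # engine, while "span" pairs two nodes per data-parallel rank:
    print(plan_groups([8, 8, 8, 8], world_size=16, dp_size=2,
                      dp_size_local=1, strategy="span"))  # -> [16, 16]
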
local_count = parallel_config.data_parallel_size_local diff --git a/vllm/v1/executor/__init__.py b/vllm/v1/executor/__init__.py index e69de29bb2d1..30d52c73791e 100644 --- a/vllm/v1/executor/__init__.py +++ b/vllm/v1/executor/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .abstract import Executor +from .uniproc_executor import UniProcExecutor + +__all__ = ["Executor", "UniProcExecutor"] diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 064e4b2bbf18..9fe1912c73e3 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,30 +1,43 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import time +from abc import ABC, abstractmethod +from collections.abc import Callable from concurrent.futures import Future -from typing import Any, Callable, Optional, Union - -import torch -import torch.distributed as dist +from functools import cached_property +from typing import TYPE_CHECKING, Literal, TypeVar, overload from vllm.config import VllmConfig -from vllm.executor.executor_base import ExecutorBase -from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0, -) -from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa -from vllm.utils import resolve_obj_by_qualname +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.tasks import SupportedTask +from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput +from vllm.v1.worker.worker_base import WorkerBase + +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase + +logger = init_logger(__name__) + +_R = TypeVar("_R") FailureCallback = Callable[[], None] -class Executor(ExecutorBase): +class Executor(ABC): + """Abstract base class for vLLM executors." + + An executor is responsible for executing the model on one device, + or it can be a distributed executor that can execute the model on multiple devices. """ - Abstract class for v1 executors, mainly define some methods for v1. - For methods shared by v0 and v1, define them in ExecutorBase""" + + uses_ray: bool = False # whether the executor uses Ray for orchestration. + supports_pp: bool = False # whether the executor supports PP @staticmethod def get_class(vllm_config: VllmConfig) -> type["Executor"]: @@ -33,16 +46,14 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: distributed_executor_backend = parallel_config.distributed_executor_backend # distributed_executor_backend must be set in VllmConfig.__post_init__ if isinstance(distributed_executor_backend, type): - if not issubclass(distributed_executor_backend, ExecutorBase): + if not issubclass(distributed_executor_backend, Executor): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {distributed_executor_backend}." + f"Executor. Got {distributed_executor_backend}." 
) executor_class = distributed_executor_backend elif distributed_executor_backend == "ray": - from vllm.v1.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor, - ) + from vllm.v1.executor.ray_executor import RayDistributedExecutor executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": @@ -50,6 +61,8 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: executor_class = MultiprocExecutor elif distributed_executor_backend == "uni": + from vllm.v1.executor.uniproc_executor import UniProcExecutor + executor_class = UniProcExecutor elif distributed_executor_backend == "external_launcher": # TODO: make v1 scheduling deterministic @@ -57,10 +70,10 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: executor_class = ExecutorWithExternalLauncher elif isinstance(distributed_executor_backend, str): executor_class = resolve_obj_by_qualname(distributed_executor_backend) - if not issubclass(executor_class, ExecutorBase): + if not issubclass(executor_class, Executor): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {executor_class}." + f"Executor. Got {executor_class}." ) else: raise ValueError( @@ -68,6 +81,29 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: ) return executor_class + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.observability_config = vllm_config.observability_config + self._init_executor() + self.is_sleeping = False + self.sleeping_tags: set[str] = set() + self.kv_output_aggregator: KVOutputAggregator | None = None + + @abstractmethod + def _init_executor(self) -> None: + raise NotImplementedError + def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the @@ -76,7 +112,7 @@ def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: self.collective_rpc("initialize_from_config", args=(kv_cache_configs,)) self.collective_rpc("compile_or_warm_up_model") - def register_failure_callback(self, callback: FailureCallback): + def register_failure_callback(self, callback: FailureCallback): # noqa: B027 """ Register a function to be called if the executor enters a permanent failed state. @@ -89,22 +125,78 @@ def determine_available_memory(self) -> list[int]: # in bytes def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: return self.collective_rpc("get_kv_cache_spec") + @overload def collective_rpc( self, - method: Union[str, Callable], - timeout: Optional[float] = None, + method: str | Callable[[WorkerBase], _R], + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict] = None, - non_block: bool = False, - ) -> list[Any]: + kwargs: dict | None = None, + non_block: Literal[False] = False, + ) -> list[_R]: + """ + Execute an RPC call on all workers. + + Args: + method: Name of the worker method to execute, or a callable that + is serialized and sent to all workers to execute. 
+ + If the method is a callable, it should accept an additional + `self` argument, in addition to the arguments passed in `args` + and `kwargs`. The `self` argument will be the worker object. + timeout: Maximum time in seconds to wait for execution. Raises a + [`TimeoutError`][] on timeout. `None` means wait indefinitely. + args: Positional arguments to pass to the worker method. + kwargs: Keyword arguments to pass to the worker method. + non_block: If `True`, returns a list of Futures instead of waiting + for the results. + + Returns: + A list containing the results from each worker. + + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. + """ + pass + + @overload + def collective_rpc( + self, + method: str | Callable[[WorkerBase], _R], + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + non_block: Literal[True] = True, + ) -> list[Future[_R]]: + pass + + @abstractmethod + def collective_rpc( + self, method, timeout=None, args=(), kwargs=None, non_block: bool = False + ): raise NotImplementedError + @overload + def execute_model( + self, + scheduler_output: SchedulerOutput, + non_block: Literal[False] = False, + ) -> ModelRunnerOutput: + pass + + @overload def execute_model( self, scheduler_output: SchedulerOutput, - non_block: bool = False, - ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - output = self.collective_rpc( + non_block: Literal[True] = True, + ) -> Future[ModelRunnerOutput]: + pass + + def execute_model( + self, scheduler_output: SchedulerOutput, non_block: bool = False + ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: + output = self.collective_rpc( # type: ignore[call-overload] "execute_model", args=(scheduler_output,), non_block=non_block ) return output[0] @@ -112,8 +204,8 @@ def execute_model( def execute_dummy_batch(self) -> None: self.collective_rpc("execute_dummy_batch") - def take_draft_token_ids(self) -> Optional[DraftTokenIds]: - output = self.collective_rpc("take_draft_token_ids") + def take_draft_token_ids(self) -> DraftTokenIds | None: + output: list[DraftTokenIds] = self.collective_rpc("take_draft_token_ids") return output[0] @property @@ -123,19 +215,120 @@ def max_concurrent_batches(self) -> int: def profile(self, is_start: bool = True): self.collective_rpc("profile", args=(is_start,)) + def save_sharded_state( + self, + path: str, + pattern: str | None = None, + max_size: int | None = None, + ) -> None: + self.collective_rpc( + "save_sharded_state", + kwargs=dict(path=path, pattern=pattern, max_size=max_size), + ) + + @abstractmethod + def check_health(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + raise NotImplementedError -class UniProcExecutor(UniProcExecutorV0, Executor): - pass + def shutdown(self) -> None: + """Shutdown the executor.""" + self.collective_rpc("shutdown") + def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None: + """Init KVOutputAggregator""" + self.kv_output_aggregator = KVOutputAggregator.from_connector( + connector, self.parallel_config.world_size + ) -class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - def determine_available_memory(self) -> list[int]: # in bytes - # same as determine_num_available_blocks in v0, - # we need to get the min across all ranks. 
- memory = super().determine_available_memory() - from vllm.distributed.parallel_state import get_world_group - - cpu_group = get_world_group().cpu_group - memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) - dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) - return [memory_tensor.item()] + @cached_property # Avoid unnecessary RPC calls + def supported_tasks(self) -> tuple[SupportedTask, ...]: + output: list[tuple[SupportedTask, ...]] + output = self.collective_rpc("get_supported_tasks") + return output[0] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("add_lora", args=(lora_request,))) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("remove_lora", args=(lora_id,))) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("pin_lora", args=(lora_id,))) + + def list_loras(self) -> set[int]: + sets: list[set[int]] = self.collective_rpc("list_loras") + for s in sets: + assert s == sets[0], "All workers should have the same LORAs." + return sets[0] + + def reset_mm_cache(self) -> None: + """Reset the multi-modal cache in each worker.""" + self.collective_rpc("reset_mm_cache") + + def start_profile(self) -> None: + self.collective_rpc("start_profile") + + def stop_profile(self) -> None: + self.collective_rpc("stop_profile") + + def sleep(self, level: int = 1): + if self.is_sleeping: + logger.warning("Executor is already sleeping.") + return + time_before_sleep = time.perf_counter() + self.collective_rpc("sleep", kwargs=dict(level=level)) + time_after_sleep = time.perf_counter() + self.sleeping_tags = {"weights", "kv_cache"} + self.is_sleeping = True + logger.info( + "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep + ) + + def wake_up(self, tags: list[str] | None = None): + if not self.is_sleeping: + logger.warning("Executor is not sleeping.") + return + if tags: + for tag in tags: + if tag not in self.sleeping_tags: + logger.warning( + "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags + ) + return + time_before_wakeup = time.perf_counter() + self.collective_rpc("wake_up", kwargs=dict(tags=tags)) + time_after_wakeup = time.perf_counter() + logger.info( + "It took %.6f seconds to wake up tags %s.", + time_after_wakeup - time_before_wakeup, + tags if tags is not None else self.sleeping_tags, + ) + if tags: + for tag in tags: + self.sleeping_tags.remove(tag) + else: + self.sleeping_tags.clear() + if not self.sleeping_tags: + self.is_sleeping = False + + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: + raise NotImplementedError + + +from vllm.v1.executor.uniproc_executor import ( # noqa: E402 + ExecutorWithExternalLauncher as _ExecutorWithExternalLauncher, +) +from vllm.v1.executor.uniproc_executor import ( # noqa: E402 + UniProcExecutor as _UniProcExecutor, +) + +# For backwards compatibility. 
+UniProcExecutor = _UniProcExecutor +ExecutorWithExternalLauncher = _ExecutorWithExternalLauncher diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 062b6042693b..1b4b9c4550f7 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -9,6 +9,7 @@ import time import traceback import weakref +from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum, auto @@ -17,7 +18,7 @@ from multiprocessing.process import BaseProcess from multiprocessing.synchronize import Lock as LockType from threading import Thread -from typing import Any, Callable, Optional, Union, cast +from typing import Any, cast import cloudpickle import torch @@ -32,21 +33,17 @@ get_pp_group, get_tp_group, ) +from vllm.envs import enable_envs_cache from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import worker_receiver_cache_from_config -from vllm.utils import ( - _maybe_force_spawn, - decorate_logs, +from vllm.utils import _maybe_force_spawn, get_mp_context +from vllm.utils.network_utils import ( get_distributed_init_method, get_loopback_ip, - get_mp_context, get_open_port, - set_process_title, ) +from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase @@ -62,8 +59,8 @@ def _init_executor(self) -> None: self._finalizer = weakref.finalize(self, self.shutdown) self.is_failed = False self.shutdown_event = threading.Event() - self.failure_callback: Optional[FailureCallback] = None - self.io_thread_pool: Optional[ThreadPoolExecutor] = None + self.failure_callback: FailureCallback | None = None + self.io_thread_pool: ThreadPoolExecutor | None = None self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size @@ -178,11 +175,11 @@ def register_failure_callback(self, callback: FailureCallback): else: self.failure_callback = callback - def execute_model( + def execute_model( # type: ignore[override] self, scheduler_output: SchedulerOutput, non_block: bool = False, - ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: + ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: if not self.has_connector: # get output only from a single worker (output_rank) (output,) = self.collective_rpc( @@ -203,6 +200,7 @@ def execute_model( ) # aggregate all workers output to a single output + assert self.kv_output_aggregator is not None if non_block: return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) @@ -210,7 +208,7 @@ def execute_model( def execute_dummy_batch(self) -> None: self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank) - def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + def take_draft_token_ids(self) -> DraftTokenIds | None: # OPTIMIZATION: Get output only from a single worker (output_rank) outputs = self.collective_rpc( "take_draft_token_ids", unique_reply_rank=self.output_rank @@ -219,12 +217,12 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def collective_rpc( self, - 
method: Union[str, Callable], - timeout: Optional[float] = None, + method: str | Callable, + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict] = None, + kwargs: dict | None = None, non_block: bool = False, - unique_reply_rank: Optional[int] = None, + unique_reply_rank: int | None = None, ) -> list[Any]: if self.is_failed: raise RuntimeError("Executor failed.") @@ -255,8 +253,8 @@ def collective_rpc( def get_response( w: WorkerProcHandle, - dequeue_timeout: Optional[float] = None, - cancel_event: Optional[threading.Event] = None, + dequeue_timeout: float | None = None, + cancel_event: threading.Event | None = None, ): status, result = w.worker_response_mq.dequeue( timeout=dequeue_timeout, cancel=cancel_event @@ -373,7 +371,7 @@ class UnreadyWorkerProcHandle: proc: BaseProcess rank: int ready_pipe: Connection - death_writer: Optional[Connection] = None + death_writer: Connection | None = None @dataclass @@ -381,7 +379,7 @@ class WorkerProcHandle: proc: BaseProcess rank: int worker_response_mq: MessageQueue # The worker process writes to this MQ - death_writer: Optional[Connection] = None + death_writer: Connection | None = None @classmethod def from_unready_handle( @@ -422,6 +420,7 @@ def __init__( "rank": rank, "distributed_init_method": distributed_init_method, "is_driver_worker": is_driver_worker, + "shared_worker_lock": shared_worker_lock, } wrapper.init_worker(all_kwargs) self.worker = wrapper @@ -445,11 +444,6 @@ def __init__( ) self.async_output_copy_thread.start() - # Initialize multimodal receiver cache if needed - self.mm_receiver_cache = worker_receiver_cache_from_config( - vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock - ) - # Initialize device self.worker.init_device() @@ -461,6 +455,10 @@ def __init__( # Load model self.worker.load_model() + # Enable environment variable cache (e.g. 
assume no more + # environment variable overrides after this point) + enable_envs_cache() + @staticmethod def make_worker_process( vllm_config: VllmConfig, @@ -512,7 +510,7 @@ def wait_for_ready( ) pipes = {handle.ready_pipe: handle for handle in unready_proc_handles} - ready_proc_handles: list[Optional[WorkerProcHandle]] = [None] * len( + ready_proc_handles: list[WorkerProcHandle | None] = [None] * len( unready_proc_handles ) while pipes: @@ -681,7 +679,7 @@ def async_output_busy_loop(self): output = self.async_output_queue.get() self.enqueue_output(output) - def worker_busy_loop(self, cancel: Optional[threading.Event] = None): + def worker_busy_loop(self, cancel: threading.Event | None = None): """Main busy loop for Multiprocessing Workers""" while True: method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( @@ -692,12 +690,7 @@ def worker_busy_loop(self, cancel: Optional[threading.Event] = None): func = getattr(self.worker, method) elif isinstance(method, bytes): func = partial(cloudpickle.loads(method), self.worker) - # retrieve from shm cache if available - if ( - self.mm_receiver_cache is not None - and func.__name__ == "execute_model" - ): - get_and_update_mm_cache(self.mm_receiver_cache, args) + output = func(*args, **kwargs) except Exception as e: # Notes have been introduced in python 3.11 diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index e2c2bfd45d7b..9a56c093ad69 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -1,112 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from concurrent.futures import Future -from typing import Optional, Union - -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -from vllm.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor as RayDistributedExecutorV0, +from vllm.v1.executor.ray_executor import ( + RayDistributedExecutor as _RayDistributedExecutor, ) -from vllm.logger import init_logger -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType -from vllm.v1.executor.abstract import Executor -from vllm.v1.outputs import ModelRunnerOutput - -logger = init_logger(__name__) - - -class FutureWrapper(Future): - """A wrapper around Ray output reference to meet the interface - of .execute_model(): The top level (core busy loop) expects .result() api - to block and return a single output. - - If aggregator is provided, the outputs from all workers are aggregated upon - the result() call. If not only the first worker's output is returned. 
- """ - - def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None): - super().__init__() - self.refs = refs - self.aggregator = aggregator - - def result(self, timeout=None): - if timeout is not None: - raise NotImplementedError("timeout is not supported") - - if self.aggregator is None: - return self.refs[0].get() - - outputs = [ref.get() for ref in self.refs] - return self.aggregator.aggregate(outputs, output_rank=0) - - -class RayDistributedExecutor(RayDistributedExecutorV0, Executor): - """Ray distributed executor using Ray Compiled Graphs.""" - - supports_pp: bool = True - - def _init_executor(self) -> None: - super()._init_executor() - - # KV connector setup - self.has_connector = self.vllm_config.kv_transfer_config is not None - - @property - def max_concurrent_batches(self) -> int: - """Ray distributed executor supports pipeline parallelism, - meaning that it allows PP size batches to be executed concurrently. - """ - if self.scheduler_config.async_scheduling: - return 2 - return self.parallel_config.pipeline_parallel_size - - def execute_model( - self, - scheduler_output: SchedulerOutput, - non_block: bool = False, - ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - """Execute the model on the Ray workers. - - Args: - scheduler_output: The scheduler output to execute. - non_block: If True, the method will return a Future. - - Returns: - The model runner output. - """ - # Build the compiled DAG for the first time. - if self.forward_dag is None: # type: ignore - self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - - refs = self.forward_dag.execute(scheduler_output) # type: ignore - - if not self.has_connector: - # Get output only from a single worker (output_rank) - # When PP is not used, we block here until the result is available. - if not non_block: - return refs[0].get() - - # When PP is used, we return a FutureWrapper immediately so that - # the scheduler can yield to the next batch. - return FutureWrapper(refs) - - # Get output from all workers when connector is present - if not non_block: - # Block and get results from all workers - outputs = [ref.get() for ref in refs] - return self.kv_output_aggregator.aggregate(outputs) - - # Return a future that will aggregate outputs from all workers - return FutureWrapper(refs, self.kv_output_aggregator) - def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest - ) -> None: - self._run_workers("reinitialize_distributed", reconfig_request) - if ( - reconfig_request.new_data_parallel_rank - == ReconfigureRankType.SHUTDOWN_CURRENT_RANK - ): - self.shutdown() +# For backwards compatibility. 
+RayDistributedExecutor = _RayDistributedExecutor diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_executor.py similarity index 57% rename from vllm/executor/ray_distributed_executor.py rename to vllm/v1/executor/ray_executor.py index 6a9608d70b69..a4823acc8764 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -1,31 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio import os from collections import defaultdict +from collections.abc import Callable +from concurrent.futures import Future from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any import cloudpickle -import msgspec import vllm.envs as envs -from vllm.executor.executor_base import DistributedExecutorBase -from vllm.executor.msgspec_utils import encode_hook -from vllm.executor.ray_utils import RayWorkerWrapper, initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy -from vllm.sequence import ExecuteModelRequest -from vllm.utils import ( - _run_task_with_lock, +from vllm.utils.network_utils import ( get_distributed_init_method, get_ip, get_open_port, - make_async, ) -from vllm.v1.outputs import SamplerOutput +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.executor.abstract import Executor +from vllm.v1.executor.ray_utils import ( + FutureWrapper, + RayWorkerWrapper, + initialize_ray_cluster, + ray, +) +from vllm.v1.outputs import ModelRunnerOutput if ray is not None: from ray.actor import ActorHandle @@ -53,7 +56,7 @@ class RayWorkerMetaData: ip: str = "" -class RayDistributedExecutor(DistributedExecutorBase): +class RayDistributedExecutor(Executor): """Ray-based distributed executor""" # These env vars are worker-specific, therefore are NOT copied @@ -69,37 +72,14 @@ class RayDistributedExecutor(DistributedExecutorBase): ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"} uses_ray: bool = True + supports_pp: bool = True def _init_executor(self) -> None: - self.forward_dag: Optional[ray.dag.CompiledDAG] = None - if envs.VLLM_USE_V1: - # V1 uses SPMD worker and compiled DAG - os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" - os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" - - # For TPU or XPU, avoid compiling NVIDIA's NCCL - if current_platform.is_tpu() or current_platform.is_xpu(): - os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" - - # If the env var is set, it uses the Ray's compiled DAG API - # which optimizes the control plane overhead. - # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. - # Currently, this requires USE_RAY_SPMD_WORKER=True. - self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG - # If the env var is set, then we do not distinguish between the - # "driver worker" vs other workers. Also, the rank 0 worker will - # be executed in a remote Ray worker. Currently this requires - # USE_RAY_COMPILED_DAG=True. - self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER - if self.use_ray_compiled_dag: - assert self.use_ray_spmd_worker, ( - "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_RAY_SPMD_WORKER=1" - ) - if self.use_ray_spmd_worker: - # TODO: Support SPMD worker for non-DAG Ray executor. 
- assert self.use_ray_compiled_dag, ( - "VLLM_USE_RAY_SPMD_WORKER=1 requires VLLM_USE_RAY_COMPILED_DAG=1" - ) + self.forward_dag: ray.dag.CompiledDAG | None = None + + # For TPU or XPU, avoid compiling NVIDIA's NCCL + if current_platform.is_tpu() or current_platform.is_xpu(): + os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" assert self.uses_ray initialize_ray_cluster(self.parallel_config) @@ -113,13 +93,17 @@ def _init_executor(self) -> None: # Create the parallel GPU workers. self._init_workers_ray(placement_group) - self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - self.output_decoder = msgspec.msgpack.Decoder(Optional[list[SamplerOutput]]) - self.use_v1 = envs.VLLM_USE_V1 + # KV connector setup + self.has_connector = self.vllm_config.kv_transfer_config is not None - self.pp_locks: Optional[list[asyncio.Lock]] = None - if not self.use_ray_compiled_dag: - self.driver_exec_method = make_async(self.driver_worker.execute_method) + @property + def max_concurrent_batches(self) -> int: + """Ray distributed executor supports pipeline parallelism, + meaning that it allows PP size batches to be executed concurrently. + """ + if self.scheduler_config.async_scheduling: + return 2 + return self.parallel_config.pipeline_parallel_size def shutdown(self) -> None: if logger: @@ -162,7 +146,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. - self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + self.driver_dummy_worker: RayWorkerWrapper | None = None # The remaining workers are the actual ray actors. self.workers: list[RayWorkerWrapper] = [] @@ -176,8 +160,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar ray_remote_kwargs ) - logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) - # Create the workers. bundle_indices: list[int] if envs.VLLM_RAY_BUNDLE_INDICES: @@ -216,7 +198,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, rpc_rank=rank) + )(RayWorkerWrapper).remote( # type: ignore[attr-defined] + vllm_config=self.vllm_config, rpc_rank=rank + ) else: worker = ray.remote( num_cpus=0, @@ -224,7 +208,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar resources={current_platform.ray_device_key: num_gpus}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, rpc_rank=rank) + )(RayWorkerWrapper).remote( # type: ignore[attr-defined] + vllm_config=self.vllm_config, rpc_rank=rank + ) worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank)) worker_ips = ray.get( @@ -237,30 +223,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar for each, ip in zip(worker_metadata, worker_ips): each.ip = ip - if not self.use_ray_spmd_worker: - for i, each in enumerate(worker_metadata): - # find and remove the dummy worker from the list - worker = each.worker - worker_ip = each.ip - if self.driver_dummy_worker is None and worker_ip == driver_ip: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. 
- self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - vllm_config=self.vllm_config, rpc_rank=0 - ) - worker_metadata.pop(i) - break - logger.debug("workers: %s", worker_metadata) logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) - if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: - raise ValueError( - "Ray does not allocate any GPUs on the driver node." - f"Driver IP: {driver_ip}, worker IPs: {worker_ips}." - "Consider adjusting the Ray placement group or running " - "the driver on a GPU node." - ) ip_counts: dict[str, int] = {} for ip in worker_ips: @@ -277,7 +241,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): should be placed first. """ ip = item.ip - return (0 if ip == driver_ip else 1, ip_counts[ip], ip) + return 0 if ip == driver_ip else 1, ip_counts[ip], ip # After sorting, the workers on the same node will be # close to each other, and the workers on the driver @@ -285,14 +249,13 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): sorted_worker_metadata = sorted( worker_metadata, key=sort_by_driver_then_worker_ip ) - start_rank = 0 if self.use_ray_spmd_worker else 1 for i, item in enumerate(sorted_worker_metadata): - item.adjusted_rank = i + start_rank + item.adjusted_rank = i self.workers = [item.worker for item in sorted_worker_metadata] rerank_mapping = { item.created_rank: item.adjusted_rank for item in sorted_worker_metadata } - self._run_workers("adjust_rank", rerank_mapping) + self.collective_rpc("adjust_rank", args=(rerank_mapping,)) # Get the set of GPU IDs used on each node. worker_node_and_gpu_ids = [] @@ -302,7 +265,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): continue worker_node_and_gpu_ids.append( ray.get(worker.get_node_and_gpu_ids.remote()) - ) # type: ignore + ) # type: ignore[attr-defined] node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -361,8 +324,8 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): self._env_vars_for_all_workers = all_args_to_update_environment_variables - self._run_workers( - "update_environment_variables", self._get_env_vars_to_be_updated() + self.collective_rpc( + "update_environment_variables", args=(self._get_env_vars_to_be_updated(),) ) if len(node_gpus) == 1: @@ -392,136 +355,95 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): or (rank % self.parallel_config.tensor_parallel_size == 0), ) all_kwargs.append(kwargs) - self._run_workers("init_worker", all_kwargs) + self.collective_rpc("init_worker", args=(all_kwargs,)) + + self.collective_rpc("init_device") + self.collective_rpc("load_model") + + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range(self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: + self.collective_rpc("reinitialize_distributed", args=(reconfig_request,)) + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): + self.shutdown() + + def execute_model( # type: ignore[override] + self, scheduler_output: SchedulerOutput, 
non_block: bool = False + ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: + """Execute the model on the Ray workers. - self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, - ) + Args: + scheduler_output: The scheduler output to execute. + non_block: If True, the method will return a Future. - if self.use_ray_spmd_worker: - for pp_rank in range(self.parallel_config.pipeline_parallel_size): - self.pp_tp_workers.append([]) - for tp_rank in range(self.parallel_config.tensor_parallel_size): - # PP=2, TP=4 - # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] - rank = ( - pp_rank * self.parallel_config.tensor_parallel_size - ) + tp_rank - assert len(self.pp_tp_workers[pp_rank]) == tp_rank - assert pp_rank < len(self.pp_tp_workers) - self.pp_tp_workers[pp_rank].append(self.workers[rank]) - - # This is the list of workers that are rank 0 of each TP group EXCEPT - # global rank 0. These are the workers that will broadcast to the - # rest of the workers. - self.tp_driver_workers: list[RayWorkerWrapper] = [] - # This is the list of workers that are not drivers and not the first - # worker in a TP group. These are the workers that will be - # broadcasted to. - self.non_driver_workers: list[RayWorkerWrapper] = [] - - # Enforce rank order for correct rank to return final output. - for index, worker in enumerate(self.workers): - # The driver worker is rank 0 and not in self.workers. - rank = index + 1 - if rank % self.parallel_config.tensor_parallel_size == 0: - self.tp_driver_workers.append(worker) - else: - self.non_driver_workers.append(worker) + Returns: + The model runner output. + """ + # Build the compiled DAG for the first time. + if self.forward_dag is None: # type: ignore + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) - def _driver_execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[list[SamplerOutput]]: - """Run execute_model in the driver worker. + refs = self.forward_dag.execute(scheduler_output) # type: ignore - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. - """ - assert not self.use_ray_spmd_worker, ( - "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1" - ) - return self.driver_worker.execute_method("execute_model", execute_model_req) + if not self.has_connector: + # Get output only from a single worker (output_rank) + # When PP is not used, we block here until the result is available. + if not non_block: + return refs[0].get() - def execute_model( - self, execute_model_req: ExecuteModelRequest - ) -> list[SamplerOutput]: - if not self.use_ray_spmd_worker: - return super().execute_model(execute_model_req) + # When PP is used, we return a FutureWrapper immediately so that + # the scheduler can yield to the next batch. 
+ return FutureWrapper(refs) - if self.forward_dag is None: - self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + # Get output from all workers when connector is present + assert self.kv_output_aggregator is not None + if not non_block: + # Block and get results from all workers + outputs = [ref.get() for ref in refs] + return self.kv_output_aggregator.aggregate(outputs) - if self.use_v1: - serialized_data = execute_model_req - else: - serialized_data = self.input_encoder.encode(execute_model_req) - outputs = ray.get(self.forward_dag.execute(serialized_data)) - output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0]) - return output + # Return a future that will aggregate outputs from all workers + return FutureWrapper(refs, self.kv_output_aggregator) - def _run_workers( + def collective_rpc( self, - method: Union[str, Callable], - *args, - async_run_tensor_parallel_workers_only: bool = False, - max_concurrent_workers: Optional[int] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers. Can be used in the following - ways: - - Args: - - async_run_tensor_parallel_workers_only: If True the method will be - run only in the remote TP workers, not the driver worker. - It will also be run asynchronously and return a list of futures - rather than blocking on the results. - - args/kwargs: All workers share the same args/kwargs - """ + method: str | Callable, + timeout: float | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + non_block: bool = False, + ) -> list[Any]: + """Runs the given method on all workers.""" sent_method = method if isinstance(method, str) else cloudpickle.dumps(method) del method - if self.use_ray_spmd_worker: - assert not async_run_tensor_parallel_workers_only, ( - "async_run_tensor_parallel_workers_only is not supported for spmd mode." - ) - if max_concurrent_workers: - raise NotImplementedError("max_concurrent_workers is not supported yet.") - - # Start the ray workers first. - ray_workers = self.workers - if async_run_tensor_parallel_workers_only: - ray_workers = self.non_driver_workers + if kwargs is None: + kwargs = {} ray_worker_outputs = [ - worker.execute_method.remote(sent_method, *args, **kwargs) - for worker in ray_workers + worker.execute_method.remote( # type: ignore[attr-defined] + sent_method, *args, **kwargs + ) + for worker in self.workers ] - if async_run_tensor_parallel_workers_only: - # Just return futures - return ray_worker_outputs - - driver_worker_output = [] - # In SPMD mode, the driver worker is the same as any other worker, - # so we only explicitly execute on the driver worker if using a - # non-SPMD worker class. - if not self.use_ray_spmd_worker: - # Start the driver worker after all the ray workers. - driver_worker_output = [ - self.driver_worker.execute_method(sent_method, *args, **kwargs) - ] - # Get the results of the ray workers. 
- if self.workers: - ray_worker_outputs = ray.get(ray_worker_outputs) - - return driver_worker_output + ray_worker_outputs + if non_block: + return [FutureWrapper((output,)) for output in ray_worker_outputs] - def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: - """Wait for futures returned from _run_workers() with - async_run_remote_workers_only to complete.""" - ray.get(parallel_worker_tasks) + return ray.get(ray_worker_outputs, timeout=timeout) def _check_ray_cgraph_installation(self): import importlib.metadata @@ -589,13 +511,6 @@ def _compiled_ray_dag(self, enable_asyncio: bool): with InputNode() as input_data: # Example DAG: PP=2, TP=4 # - # For V0: - # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501 - # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501 - # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501 - # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501 - # - # For V1: # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501 # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501 # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501 @@ -607,20 +522,10 @@ def _compiled_ray_dag(self, enable_asyncio: bool): for pp_rank, tp_group in enumerate(self.pp_tp_workers): # Each PP worker takes in the output of the previous PP worker, # and the TP group executes in SPMD fashion. - if self.use_v1: - outputs = [ - worker.execute_model_ray.bind( # type: ignore[attr-defined] - outputs[i] - ) - for i, worker in enumerate(tp_group) - ] - else: - outputs = [ - worker.execute_model_spmd.bind( # type: ignore[attr-defined] - outputs[i] - ) - for i, worker in enumerate(tp_group) - ] + outputs = [ + worker.execute_model_ray.bind(outputs[i]) # type: ignore[attr-defined] + for i, worker in enumerate(tp_group) + ] last_pp_rank = len(self.pp_tp_workers) - 1 if ( @@ -668,75 +573,6 @@ def _compiled_ray_dag(self, enable_asyncio: bool): def __del__(self): self.shutdown() - async def execute_model_async( - self, execute_model_req: ExecuteModelRequest - ) -> list[SamplerOutput]: - if not self.use_ray_spmd_worker: - return await super().execute_model_async(execute_model_req) - - if self.forward_dag is None: - self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) - - serialized_data = self.input_encoder.encode(execute_model_req) - dag_future = await self.forward_dag.execute_async(serialized_data) - output = await dag_future[0] - return self.output_decoder.decode(output) - - async def _driver_execute_model_async( - self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> list[SamplerOutput]: - assert not self.use_ray_spmd_worker, ( - "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1" - ) - if not self.tp_driver_workers: - return await self.driver_exec_method("execute_model", execute_model_req) - if self.pp_locks is None: - # This locks each pipeline parallel stage so multiple virtual - # engines can't execute on the same stage at the same time - # We create the locks here to avoid creating them in the constructor - # which uses a different asyncio loop. 
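# Worked example of the flat rank layout behind pp_tp_workers above
# (pure arithmetic, no vLLM objects needed): PP=2, TP=4.
pipeline_parallel_size, tensor_parallel_size = 2, 4
pp_tp_workers = [
    [pp_rank * tensor_parallel_size + tp_rank
     for tp_rank in range(tensor_parallel_size)]
    for pp_rank in range(pipeline_parallel_size)
]
assert pp_tp_workers == [[0, 1, 2, 3], [4, 5, 6, 7]]
# In the compiled DAG sketched above, stage 0 (ranks 0-3) hands
# (SchedulerOutput, IntermediateTensors) to stage 1 (ranks 4-7), which emits
# the ModelRunnerOutput.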
- self.pp_locks = [ - asyncio.Lock() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - tasks = [ - asyncio.create_task( - _run_task_with_lock( - self.driver_exec_method, - self.pp_locks[0], - "execute_model", - execute_model_req, - ) - ) - ] - for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1): - tasks.append( - asyncio.create_task( - _run_task_with_lock( - driver_worker.execute_method.remote, - self.pp_locks[pp_rank], - "execute_model", - execute_model_req, - ) - ) - ) - - results = await asyncio.gather(*tasks) - - # Only the last PP stage has the final results. - return results[-1] - - async def _start_worker_execution_loop(self): - assert not self.use_ray_spmd_worker, ( - "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1" - ) - coros = [ - worker.execute_method.remote("start_worker_execution_loop") - for worker in self.non_driver_workers - ] - return await asyncio.gather(*coros) - def check_health(self) -> None: # Assume that the Ray workers are healthy. # TODO: check the health of the Ray workers diff --git a/vllm/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py similarity index 88% rename from vllm/executor/ray_utils.py rename to vllm/v1/executor/ray_utils.py index c3c8a70678ad..518f1582faeb 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -4,18 +4,17 @@ import os import time from collections import defaultdict -from typing import TYPE_CHECKING, Optional, Union - -import msgspec +from concurrent.futures import Future +from typing import TYPE_CHECKING, Union import vllm.platforms from vllm.config import ParallelConfig from vllm.distributed import get_pp_group -from vllm.executor.msgspec_utils import decode_hook, encode_hook +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.utils import get_ip +from vllm.sequence import IntermediateTensors +from vllm.utils.network_utils import get_ip from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase @@ -51,11 +50,6 @@ def __init__(self, *args, **kwargs) -> None: # that thread. self.compiled_dag_cuda_device_set = False - self.input_decoder = msgspec.msgpack.Decoder( - ExecuteModelRequest, dec_hook=decode_hook - ) - self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - def get_node_ip(self) -> str: return get_ip() @@ -70,44 +64,6 @@ def get_node_and_gpu_ids(self) -> tuple[str, list[int]]: gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key] return node_id, gpu_ids - def execute_model_spmd( - self, - req_or_tuple: Union[bytes, tuple[bytes, Optional[IntermediateTensors]]], - ) -> bytes: - """Execute model in SPMD fashion: used only when SPMD worker and - compiled DAG are both enabled. - - Args: - req_or_tuple: A request or a tuple containing the - request and intermediate tensors. Intermediate tensors are - None unless if it is provided because it is > 0 pipeline - stage. The request is serialized by msgspec. - """ - if isinstance(req_or_tuple, bytes): - serialized_req, intermediate_tensors = req_or_tuple, None - else: - serialized_req, intermediate_tensors = req_or_tuple - - execute_model_req = self.input_decoder.decode(serialized_req) - - # TODO(swang): This is needed right now because Ray Compiled Graph - # executes on a background thread, so we need to reset torch's - # current device. 
- if not self.compiled_dag_cuda_device_set: - current_platform.set_device(self.worker.device) - self.compiled_dag_cuda_device_set = True - - output = self.worker._execute_model_spmd( - execute_model_req, intermediate_tensors - ) - # Pipeline model request and output to the next pipeline stage. - if isinstance(output, IntermediateTensors): - output = serialized_req, output - else: - output = self.output_encoder.encode(output) - - return output - def setup_device_if_necessary(self): # TODO(swang): This is needed right now because Ray CG executes # on a background thread, so we need to reset torch's current @@ -119,6 +75,7 @@ def setup_device_if_necessary(self): # Not needed pass else: + assert self.worker.device is not None current_platform.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True @@ -139,6 +96,7 @@ def execute_model_ray( scheduler_output, intermediate_tensors = scheduler_output else: scheduler_output, intermediate_tensors = scheduler_output, None + assert self.worker.model_runner is not None output = self.worker.model_runner.execute_model( scheduler_output, intermediate_tensors ) @@ -169,6 +127,31 @@ def override_env_vars(self, vars: dict[str, str]): RayWorkerWrapper = None # type: ignore +class FutureWrapper(Future): + """A wrapper around Ray output reference to meet the interface + of .execute_model(): The top level (core busy loop) expects .result() api + to block and return a single output. + + If aggregator is provided, the outputs from all workers are aggregated upon + the result() call. If not only the first worker's output is returned. + """ + + def __init__(self, refs, aggregator: KVOutputAggregator | None = None): + super().__init__() + self.refs = refs + self.aggregator = aggregator + + def result(self, timeout=None): + if timeout is not None: + raise NotImplementedError("timeout is not supported") + + if self.aggregator is None: + return self.refs[0].get() + + outputs = [ref.get() for ref in self.refs] + return self.aggregator.aggregate(outputs, output_rank=0) + + def ray_is_available() -> bool: """Returns True if Ray is available.""" return ray is not None @@ -300,7 +283,7 @@ def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): def initialize_ray_cluster( parallel_config: ParallelConfig, - ray_address: Optional[str] = None, + ray_address: str | None = None, ): """Initialize the distributed cluster with Ray. 
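# Minimal consumption sketch for the FutureWrapper added above: any object
# exposing .get() can stand in for a Ray object ref (the _Ref class here is
# made up for illustration).
class _Ref:
    def __init__(self, value):
        self._value = value

    def get(self):  # mirrors the blocking .get() of a Ray object ref
        return self._value


fut = FutureWrapper([_Ref("rank-0 output"), _Ref("rank-1 output")])
# Without an aggregator, .result() blocks on the first ref only.
assert fut.result() == "rank-0 output"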
diff --git a/vllm/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py similarity index 69% rename from vllm/executor/uniproc_executor.py rename to vllm/v1/executor/uniproc_executor.py index 8206f23d1878..0d072172fdf3 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -1,56 +1,50 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor from functools import cached_property from multiprocessing import Lock -from typing import Any, Callable, Optional, Union +from typing import Any import torch import torch.distributed as dist import vllm.envs as envs -from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import worker_receiver_cache_from_config -from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method +from vllm.utils import run_method +from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType -from vllm.v1.executor.utils import get_and_update_mm_cache +from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) -class UniProcExecutor(ExecutorBase): - uses_ray: bool = False - +class UniProcExecutor(Executor): def _init_executor(self) -> None: """Initialize the worker and load the model.""" self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) distributed_init_method, rank, local_rank = self._distributed_args() - is_driver_worker = True kwargs = dict( vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker, - ) - self.mm_receiver_cache = worker_receiver_cache_from_config( - self.vllm_config, MULTIMODAL_REGISTRY, Lock() + is_driver_worker=True, + shared_worker_lock=Lock(), ) - self.async_output_thread: Optional[ThreadPoolExecutor] = None + self.async_output_thread: ThreadPoolExecutor | None = None if self.max_concurrent_batches > 1: self.async_output_thread = ThreadPoolExecutor( max_workers=1, thread_name_prefix="WorkerAsyncOutput" ) - self.collective_rpc("init_worker", args=([kwargs],)) - self.collective_rpc("init_device") - self.collective_rpc("load_model") + self.driver_worker.init_worker(all_kwargs=[kwargs]) + self.driver_worker.init_device() + self.driver_worker.load_model() def _distributed_args(self) -> tuple[str, int, int]: """Return (distributed_init_method, rank, local_rank).""" @@ -66,16 +60,14 @@ def max_concurrent_batches(self) -> int: def collective_rpc( self, - method: Union[str, Callable], - timeout: Optional[float] = None, + method: str | Callable, + timeout: float | None = None, args: tuple = (), - kwargs: Optional[dict] = None, + kwargs: dict | None = None, non_block: bool = False, ) -> list[Any]: if kwargs is None: kwargs = {} - if self.mm_receiver_cache is not None and method == "execute_model": - get_and_update_mm_cache(self.mm_receiver_cache, args) if not non_block: return [run_method(self.driver_worker, method, args, kwargs)] @@ -107,16 +99,12 @@ def reinitialize_distributed( == ReconfigureRankType.SHUTDOWN_CURRENT_RANK ): self.shutdown() - return def shutdown(self) -> 
None: if worker := self.driver_worker: worker.shutdown() -UniProcExecutorAsync = UniProcExecutor - - class ExecutorWithExternalLauncher(UniProcExecutor): """An executor that uses external launchers to launch engines, specially designed for torchrun-compatible launchers, for @@ -134,8 +122,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor): and they don't need to synchronize the states with each other. """ - uses_ray: bool = False - def _init_executor(self) -> None: """Initialize the worker and load the model.""" if envs.VLLM_USE_V1: @@ -158,22 +144,12 @@ def _distributed_args(self) -> tuple[str, int, int]: local_rank = int(os.environ["LOCAL_RANK"]) return distributed_init_method, rank, local_rank - def determine_num_available_blocks(self) -> tuple[int, int]: - """ - Determine the number of available KV blocks. - Add an additional all_reduce to get the min across all ranks. - Note that even if we have the same `gpu_memory_utilization` and - `swap_space`, the available memory in every rank might still - differ because NCCL can take different amounts of memory in - different ranks. Therefore, it is necessary to test if all ranks - agree on the same KV cache configuration. - """ - a, b = super().determine_num_available_blocks() + def determine_available_memory(self) -> list[int]: # in bytes + # we need to get the min across all ranks. + memory = super().determine_available_memory() from vllm.distributed.parallel_state import get_world_group cpu_group = get_world_group().cpu_group - a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64) - b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64) - dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN) - dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN) - return a_tensor.item(), b_tensor.item() + memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) + dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) + return [memory_tensor.item()] diff --git a/vllm/v1/executor/utils.py b/vllm/v1/executor/utils.py deleted file mode 100644 index 884068a43882..000000000000 --- a/vllm/v1/executor/utils.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.multimodal.cache import ShmObjectStoreReceiverCache -from vllm.v1.core.sched.output import SchedulerOutput - - -def get_and_update_mm_cache( - receiver_cache: ShmObjectStoreReceiverCache, - args: tuple[SchedulerOutput], -) -> None: - """ - For each MultiModalKwargsItem in SchedulerOutput, fetch from shared memory - cache as needed. - - Args: - receiver_cache: The receiver cache to update. - args: According to the collective_rpc call of execute_model method in - executor, args is a tuple of only one SchedulerOutput element. 
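# Why the MIN all-reduce above is needed: each rank can report a different
# amount of available memory, and all ranks must agree on the smallest value
# before sizing the KV cache (plain-Python stand-in, illustrative byte counts).
per_rank_available_bytes = [7_516_192_768, 7_381_975_040, 7_247_757_312, 7_516_192_768]
agreed = min(per_rank_available_bytes)  # what ReduceOp.MIN leaves on every rank
assert agreed == 7_247_757_312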
- """ - scheduler_output = args[0] - for request_data in scheduler_output.scheduled_new_reqs: - request_data.mm_features = receiver_cache.get_and_update_features( - request_data.mm_features - ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 9c28eb92c17a..392519f8fa9a 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -4,14 +4,14 @@ import copy from dataclasses import dataclass, fields from math import prod -from typing import Optional import torch from typing_extensions import Self from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import cdiv, get_dtype_size +from vllm.utils import cdiv +from vllm.utils.torch_utils import get_dtype_size logger = init_logger(__name__) @@ -74,8 +74,8 @@ def page_size_bytes(self) -> int: @dataclass(frozen=True) class FullAttentionSpec(AttentionSpec): - sliding_window: Optional[int] = None - attention_chunk_size: Optional[int] = None + sliding_window: int | None = None + attention_chunk_size: int | None = None """ When hybrid allocator is disabled and the model contains both full attention layers and sliding window attention layers, sliding @@ -96,7 +96,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return cdiv(max_model_len, self.block_size) * self.page_size_bytes @classmethod - def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]: + def merge_window_sizes(cls, window_sizes: set[int]) -> int | None: if len(window_sizes) == 0: return None elif len(window_sizes) == 1: @@ -154,7 +154,7 @@ def merge(cls, specs: list[Self]) -> Self: @dataclass(frozen=True) class MLAAttentionSpec(FullAttentionSpec): # TODO(Lucas/Chen): less hacky way to do this - cache_dtype_str: Optional[str] = None + cache_dtype_str: str | None = None @property def page_size_bytes(self) -> int: @@ -237,7 +237,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: class MambaSpec(KVCacheSpec): shapes: tuple[tuple[int, ...], ...] dtypes: tuple[torch.dtype] - page_size_padded: Optional[int] = None + page_size_padded: int | None = None mamba_type: str = "mamba2" num_speculative_blocks: int = 0 @@ -342,7 +342,7 @@ def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool: ) @classmethod - def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Optional[Self]: + def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None: """ Return a SameTypeKVCacheSpecs object if all layers have the same type of KV cache spec. Return None if not. diff --git a/vllm/v1/kv_offload/abstract.py b/vllm/v1/kv_offload/abstract.py index ce2d0dffc0ff..c1d1cbebc175 100644 --- a/vllm/v1/kv_offload/abstract.py +++ b/vllm/v1/kv_offload/abstract.py @@ -30,7 +30,6 @@ from abc import ABC, abstractmethod from collections.abc import Iterable from dataclasses import dataclass -from typing import Optional from vllm.v1.core.kv_cache_utils import BlockHash @@ -122,7 +121,7 @@ def complete_load(self, block_hashes: Iterable[BlockHash]): @abstractmethod def prepare_store( self, block_hashes: Iterable[BlockHash] - ) -> Optional[PrepareStoreOutput]: + ) -> PrepareStoreOutput | None: """ Prepare the given blocks to be offloaded. 
The given blocks will be protected from eviction until diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py index 0c1cf64a237c..250ed5e95af4 100644 --- a/vllm/v1/kv_offload/cpu.py +++ b/vllm/v1/kv_offload/cpu.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterator -from typing import Optional import torch @@ -29,10 +28,10 @@ def __init__(self, vllm_config: VllmConfig): self.num_cpu_blocks: int = num_cpu_blocks # scheduler-side - self._manager: Optional[OffloadingManager] = None + self._manager: OffloadingManager | None = None # worker-side - self._handler: Optional[OffloadingHandler] = None + self._handler: OffloadingHandler | None = None def get_manager(self) -> OffloadingManager: if not self._manager: diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py index e0a53460e840..b4d40cb48e1d 100644 --- a/vllm/v1/kv_offload/factory.py +++ b/vllm/v1/kv_offload/factory.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib -from typing import TYPE_CHECKING, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING from vllm.logger import init_logger from vllm.v1.kv_offload.spec import OffloadingSpec diff --git a/vllm/v1/kv_offload/lru_manager.py b/vllm/v1/kv_offload/lru_manager.py index 36f5eb4a0abd..0a0111f88790 100644 --- a/vllm/v1/kv_offload/lru_manager.py +++ b/vllm/v1/kv_offload/lru_manager.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import OrderedDict from collections.abc import Iterable -from typing import Optional from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.kv_offload.abstract import ( @@ -23,7 +22,7 @@ def __init__(self, backend: Backend, enable_events: bool = False): self.backend: Backend = backend # block_hash -> BlockStatus self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() - self.events: Optional[list[OffloadingEvent]] = [] if enable_events else None + self.events: list[OffloadingEvent] | None = [] if enable_events else None def lookup(self, block_hashes: Iterable[BlockHash]) -> int: hit_count = 0 @@ -57,7 +56,7 @@ def complete_load(self, block_hashes: Iterable[BlockHash]): def prepare_store( self, block_hashes: Iterable[BlockHash] - ) -> Optional[PrepareStoreOutput]: + ) -> PrepareStoreOutput | None: # filter out blocks that are already stored block_hashes_to_store = [ block_hash for block_hash in block_hashes if block_hash not in self.blocks diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index eb7117a400b9..646f9d0d7542 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.attention import AttentionBackend from vllm.logger import init_logger -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.worker.worker import ( OffloadingHandler, diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 541af7af1725..c5d7885eefb7 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -4,22 +4,30 @@ import logging import time from abc import ABC, abstractmethod -from typing import Callable, Optional, Union +from 
collections.abc import Callable +from typing import TypeAlias -import prometheus_client +from prometheus_client import Counter, Gauge, Histogram from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics +from vllm.plugins import load_plugins_by_group from vllm.v1.engine import FinishReason from vllm.v1.metrics.prometheus import unregister_vllm_metrics -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import ( + CachingMetrics, + IterationStats, + MultiModalCacheStats, + SchedulerStats, +) from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) -StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] +PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] +AggregateStatLoggerFactory = type["AggregateStatLoggerBase"] +StatLoggerFactory = AggregateStatLoggerFactory | PerEngineStatLoggerFactory class StatLoggerBase(ABC): @@ -36,8 +44,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod def record( self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, engine_idx: int = 0, ): ... @@ -48,20 +57,52 @@ def log(self): # noqa pass +def load_stat_logger_plugin_factories() -> list[StatLoggerFactory]: + factories: list[StatLoggerFactory] = [] + + for name, plugin_class in load_plugins_by_group("vllm.stat_logger_plugins").items(): + if not isinstance(plugin_class, type) or not issubclass( + plugin_class, StatLoggerBase + ): + raise TypeError( + f"Stat logger plugin {name!r} must be a subclass of " + f"StatLoggerBase (got {plugin_class!r})." + ) + + factories.append(plugin_class) + + return factories + + +class AggregateStatLoggerBase(StatLoggerBase): + """Abstract base class for loggers that + aggregate across multiple DP engines.""" + + @abstractmethod + def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]): ... + + class LoggingStatLogger(StatLoggerBase): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self.vllm_config = vllm_config self._reset(time.monotonic()) + self.last_scheduler_stats = SchedulerStats() - # Prefix cache metrics. This cannot be reset. + + # Caching metrics. This cannot be reset. # TODO: Make the interval configurable. 
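# Sketch of a stat-logger plugin that load_stat_logger_plugin_factories()
# above would accept: a StatLoggerBase subclass published under the
# "vllm.stat_logger_plugins" entry-point group. The class name is
# hypothetical, packaging/registration details are omitted, and the types
# referenced are the ones imported at the top of this module.
class StdoutStatLogger(StatLoggerBase):
    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        self.engine_index = engine_index

    def record(
        self,
        scheduler_stats: SchedulerStats | None,
        iteration_stats: IterationStats | None,
        mm_cache_stats: MultiModalCacheStats | None = None,
        engine_idx: int = 0,
    ):
        if scheduler_stats is not None:
            print(
                f"engine={engine_idx} "
                f"running={scheduler_stats.num_running_reqs} "
                f"waiting={scheduler_stats.num_waiting_reqs}"
            )

    def log_engine_initialized(self):
        print(f"engine {self.engine_index} initialized")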
- self.prefix_caching_metrics = PrefixCachingMetrics() + self.prefix_caching_metrics = CachingMetrics() + self.connector_prefix_caching_metrics = CachingMetrics() + self.mm_caching_metrics = CachingMetrics() + self.spec_decoding_logging = SpecDecodingLogging() kv_tranfer_config = self.vllm_config.kv_transfer_config self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config) self.last_prompt_throughput: float = 0.0 self.last_generation_throughput: float = 0.0 + self.engine_is_idle = False + self.aggregated = False def _reset(self, now): self.last_log_time = now @@ -82,10 +123,15 @@ def _get_throughput(self, tracked_stats: int, now: float) -> float: return 0.0 return float(tracked_stats / delta_time) + @property + def log_prefix(self): + return "Engine {:03d}: ".format(self.engine_index) + def record( self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, engine_idx: int = 0, ): """Log Stats to standard output.""" @@ -95,57 +141,81 @@ def record( if scheduler_stats is not None: self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) + if scheduler_stats.connector_prefix_cache_stats is not None: + self.connector_prefix_caching_metrics.observe( + scheduler_stats.connector_prefix_cache_stats + ) + if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) if kv_connector_stats := scheduler_stats.kv_connector_stats: self.kv_connector_logging.observe(kv_connector_stats) - self.last_scheduler_stats = scheduler_stats + if not self.aggregated: + self.last_scheduler_stats = scheduler_stats + if mm_cache_stats: + self.mm_caching_metrics.observe(mm_cache_stats) - def log(self): + def _update_stats(self): now = time.monotonic() prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) generation_throughput = self._get_throughput(self.num_generation_tokens, now) self._reset(now) - - scheduler_stats = self.last_scheduler_stats - - log_fn = logger.info - if not any( + self.engine_is_idle = not any( ( prompt_throughput, generation_throughput, self.last_prompt_throughput, self.last_generation_throughput, ) - ): - # Avoid log noise on an idle production system - log_fn = logger.debug + ) self.last_generation_throughput = generation_throughput self.last_prompt_throughput = prompt_throughput + def aggregate_scheduler_stats(self): + # noop for per engine loggers + return + + def log(self): + self._update_stats() + self.aggregate_scheduler_stats() + # Avoid log noise on an idle production system + log_fn = logger.debug if self.engine_is_idle else logger.info # Format and print output. 
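# Worked numbers for the throughput figures computed by _get_throughput()
# above: tokens accumulated since the previous log divided by the elapsed
# interval (values are illustrative).
num_prompt_tokens, num_generation_tokens = 4_500, 1_280
delta_time = 10.0  # seconds since the last log
avg_prompt_throughput = num_prompt_tokens / delta_time          # 450.0 tokens/s
avg_generation_throughput = num_generation_tokens / delta_time  # 128.0 tokens/s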
- log_fn( - "Engine %03d: " - "Avg prompt throughput: %.1f tokens/s, " - "Avg generation throughput: %.1f tokens/s, " - "Running: %d reqs, Waiting: %d reqs, " - "GPU KV cache usage: %.1f%%, " + log_parts = [ + "Avg prompt throughput: %.1f tokens/s", + "Avg generation throughput: %.1f tokens/s", + "Running: %d reqs", + "Waiting: %d reqs", + "GPU KV cache usage: %.1f%%", "Prefix cache hit rate: %.1f%%", - self.engine_index, - prompt_throughput, - generation_throughput, - scheduler_stats.num_running_reqs, - scheduler_stats.num_waiting_reqs, - scheduler_stats.kv_cache_usage * 100, + ] + log_args = [ + self.last_prompt_throughput, + self.last_generation_throughput, + self.last_scheduler_stats.num_running_reqs, + self.last_scheduler_stats.num_waiting_reqs, + self.last_scheduler_stats.kv_cache_usage * 100, self.prefix_caching_metrics.hit_rate * 100, + ] + if not self.connector_prefix_caching_metrics.empty: + log_parts.append("External prefix cache hit rate: %.1f%%") + log_args.append(self.connector_prefix_caching_metrics.hit_rate * 100) + if not self.mm_caching_metrics.empty: + log_parts.append("MM cache hit rate: %.1f%%") + log_args.append(self.mm_caching_metrics.hit_rate * 100) + + log_fn( + self.log_prefix + ", ".join(log_parts), + *log_args, ) + self.spec_decoding_logging.log(log_fn=log_fn) self.kv_connector_logging.log(log_fn=log_fn) def log_engine_initialized(self): if self.vllm_config.cache_config.num_gpu_blocks: - logger.info( + logger.debug( "Engine %03d: vllm cache_config_info with initialization " "after num_gpu_blocks is: %d", self.engine_index, @@ -153,17 +223,125 @@ def log_engine_initialized(self): ) -class PrometheusStatLogger(StatLoggerBase): - _gauge_cls = prometheus_client.Gauge - _counter_cls = prometheus_client.Counter - _histogram_cls = prometheus_client.Histogram +class AggregatedLoggingStatLogger(LoggingStatLogger, AggregateStatLoggerBase): + def __init__( + self, + vllm_config: VllmConfig, + engine_indexes: list[int], + ): + self.engine_indexes = engine_indexes + self.last_scheduler_stats_dict: dict[int, SchedulerStats] = { + idx: SchedulerStats() for idx in self.engine_indexes + } + LoggingStatLogger.__init__(self, vllm_config, engine_index=-1) + self.aggregated = True + + @property + def log_prefix(self): + return "{} Engines Aggregated: ".format(len(self.engine_indexes)) + + def record( + self, + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, + engine_idx: int = 0, + ): + if engine_idx not in self.engine_indexes: + logger.warning("Unexpected engine_idx: %d", engine_idx) + return + LoggingStatLogger.record( + self, + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) + if scheduler_stats is not None: + self.last_scheduler_stats_dict[engine_idx] = scheduler_stats + + def aggregate_scheduler_stats(self): + self.last_scheduler_stats = SchedulerStats() + for last_scheduler_stats in self.last_scheduler_stats_dict.values(): + self.last_scheduler_stats.num_waiting_reqs += ( + last_scheduler_stats.num_waiting_reqs + ) + self.last_scheduler_stats.num_running_reqs += ( + last_scheduler_stats.num_running_reqs + ) + self.last_scheduler_stats.num_corrupted_reqs += ( + last_scheduler_stats.num_corrupted_reqs + ) + self.last_scheduler_stats.kv_cache_usage += ( + last_scheduler_stats.kv_cache_usage + ) + self.last_scheduler_stats.kv_cache_usage /= len(self.last_scheduler_stats_dict) + + def log(self): + LoggingStatLogger.log(self) + + def 
log_engine_initialized(self): + if self.vllm_config.cache_config.num_gpu_blocks: + logger.info( + "%d Engines: vllm cache_config_info with initialization " + "after num_gpu_blocks is: %d", + len(self.engine_indexes), + self.vllm_config.cache_config.num_gpu_blocks, + ) + + +class PerEngineStatLoggerAdapter(AggregateStatLoggerBase): + def __init__( + self, + vllm_config: VllmConfig, + engine_indexes: list[int], + per_engine_stat_logger_factory: PerEngineStatLoggerFactory, + ) -> None: + self.per_engine_stat_loggers = {} + self.engine_indexes = engine_indexes + for engine_index in engine_indexes: + self.per_engine_stat_loggers[engine_index] = per_engine_stat_logger_factory( + vllm_config, engine_index + ) + + def record( + self, + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, + engine_idx: int = 0, + ): + if engine_idx not in self.per_engine_stat_loggers: + logger.warning("Unexpected engine_idx: %d", engine_idx) + return + self.per_engine_stat_loggers[engine_idx].record( + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) + + def log(self): + for per_engine_stat_logger in self.per_engine_stat_loggers.values(): + per_engine_stat_logger.log() + + def log_engine_initialized(self): + for per_engine_stat_logger in self.per_engine_stat_loggers.values(): + per_engine_stat_logger.log_engine_initialized() + + +class PrometheusStatLogger(AggregateStatLoggerBase): + _gauge_cls = Gauge + _counter_cls = Counter + _histogram_cls = Histogram _spec_decoding_cls = SpecDecodingProm def __init__( - self, vllm_config: VllmConfig, engine_indexes: Optional[list[int]] = None + self, vllm_config: VllmConfig, engine_indexes: list[int] | None = None ): if engine_indexes is None: engine_indexes = [0] + self.engine_indexes = engine_indexes unregister_vllm_metrics() @@ -288,6 +466,60 @@ def __init__( counter_prefix_cache_hits, engine_indexes, model_name ) + # + # External - KV connector prefix cache + # + + counter_connector_prefix_cache_queries = self._counter_cls( + name="vllm:external_prefix_cache_queries", + documentation=( + "External prefix cache queries from KV connector " + "cross-instance cache sharing, in terms of number of queried tokens." + ), + labelnames=labelnames, + ) + self.counter_connector_prefix_cache_queries = make_per_engine( + counter_connector_prefix_cache_queries, engine_indexes, model_name + ) + + counter_connector_prefix_cache_hits = self._counter_cls( + name="vllm:external_prefix_cache_hits", + documentation=( + "External prefix cache hits from KV connector " + "cross-instance cache sharing, in terms of number of cached tokens." + ), + labelnames=labelnames, + ) + self.counter_connector_prefix_cache_hits = make_per_engine( + counter_connector_prefix_cache_hits, engine_indexes, model_name + ) + + # + # Multi-modal cache + # + + counter_mm_cache_queries = self._counter_cls( + name="vllm:mm_cache_queries", + documentation=( + "Multi-modal cache queries, in terms of number of queried items." + ), + labelnames=labelnames, + ) + self.counter_mm_cache_queries = make_per_engine( + counter_mm_cache_queries, engine_indexes, model_name + ) + + counter_mm_cache_hits = self._counter_cls( + name="vllm:mm_cache_hits", + documentation=( + "Multi-modal cache hits, in terms of number of cached items." 
+ ), + labelnames=labelnames, + ) + self.counter_mm_cache_hits = make_per_engine( + counter_mm_cache_hits, engine_indexes, model_name + ) + # # Counters # @@ -318,9 +550,7 @@ def __init__( counter_generation_tokens, engine_indexes, model_name ) - self.counter_request_success: dict[ - FinishReason, dict[int, prometheus_client.Counter] - ] = {} + self.counter_request_success: dict[FinishReason, dict[int, Counter]] = {} counter_request_success_base = self._counter_cls( name="vllm:request_success", documentation="Count of successfully processed requests.", @@ -610,7 +840,7 @@ def __init__( # TODO: This metric might be incorrect in case of using multiple # api_server counts which uses prometheus mp. - self.gauge_lora_info: Optional[prometheus_client.Gauge] = None + self.gauge_lora_info: Gauge | None = None if vllm_config.lora_config is not None: if len(self.engine_indexes) > 1: raise NotImplementedError("LoRA in DP mode is not supported yet.") @@ -655,8 +885,9 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): def record( self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, engine_idx: int = 0, ): """Log to prometheus.""" @@ -689,11 +920,23 @@ def record( scheduler_stats.prefix_cache_stats.hits ) + if scheduler_stats.connector_prefix_cache_stats is not None: + self.counter_connector_prefix_cache_queries[engine_idx].inc( + scheduler_stats.connector_prefix_cache_stats.queries + ) + self.counter_connector_prefix_cache_hits[engine_idx].inc( + scheduler_stats.connector_prefix_cache_stats.hits + ) + if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_prom.observe( scheduler_stats.spec_decoding_stats, engine_idx ) + if mm_cache_stats is not None: + self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) + self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits) + if iteration_stats is None: return @@ -771,11 +1014,7 @@ def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) -PromMetric = Union[ - prometheus_client.Gauge, - prometheus_client.Counter, - prometheus_client.Histogram, -] +PromMetric: TypeAlias = Gauge | Counter | Histogram def make_per_engine( @@ -827,17 +1066,17 @@ class StatLoggerManager: def __init__( self, vllm_config: VllmConfig, - engine_idxs: Optional[list[int]] = None, - custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, + engine_idxs: list[int] | None = None, + custom_stat_loggers: list[StatLoggerFactory] | None = None, enable_default_loggers: bool = True, + aggregate_engine_logging: bool = False, client_count: int = 1, ): - self.engine_idxs = engine_idxs if engine_idxs else [0] - - factories: list[StatLoggerFactory] = [] + self.engine_indexes = engine_idxs if engine_idxs else [0] + self.stat_loggers: list[AggregateStatLoggerBase] = [] + stat_logger_factories: list[StatLoggerFactory] = [] if custom_stat_loggers is not None: - factories.extend(custom_stat_loggers) - + stat_logger_factories.extend(custom_stat_loggers) if enable_default_loggers and logger.isEnabledFor(logging.INFO): if client_count > 1: logger.warning( @@ -845,51 +1084,57 @@ def __init__( "disabling stats logging to avoid incomplete stats." 
) else: - factories.append(LoggingStatLogger) - - # engine_idx: StatLogger - self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {} - prometheus_factory = PrometheusStatLogger - for engine_idx in self.engine_idxs: - loggers: list[StatLoggerBase] = [] - for logger_factory in factories: - # If we get a custom prometheus logger, use that - # instead. This is typically used for the ray case. - if isinstance(logger_factory, type) and issubclass( - logger_factory, PrometheusStatLogger - ): - prometheus_factory = logger_factory - continue - loggers.append(logger_factory(vllm_config, engine_idx)) # type: ignore - self.per_engine_logger_dict[engine_idx] = loggers - - # For Prometheus, need to share the metrics between EngineCores. - # Each EngineCore's metrics are expressed as a unique label. - self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs) + default_logger_factory = ( + AggregatedLoggingStatLogger + if aggregate_engine_logging + else LoggingStatLogger + ) + stat_logger_factories.append(default_logger_factory) + custom_prometheus_logger: bool = False + for stat_logger_factory in stat_logger_factories: + if isinstance(stat_logger_factory, type) and issubclass( + stat_logger_factory, AggregateStatLoggerBase + ): + global_stat_logger = stat_logger_factory( + vllm_config=vllm_config, + engine_indexes=self.engine_indexes, + ) + if isinstance(global_stat_logger, PrometheusStatLogger): + custom_prometheus_logger = True + else: + # per engine logger + global_stat_logger = PerEngineStatLoggerAdapter( + vllm_config=vllm_config, + engine_indexes=self.engine_indexes, + per_engine_stat_logger_factory=stat_logger_factory, # type: ignore[arg-type] + ) + self.stat_loggers.append(global_stat_logger) + if not custom_prometheus_logger: + self.stat_loggers.append( + PrometheusStatLogger(vllm_config, self.engine_indexes) + ) def record( self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: Optional[int] = None, + scheduler_stats: SchedulerStats | None, + iteration_stats: IterationStats | None, + mm_cache_stats: MultiModalCacheStats | None = None, + engine_idx: int | None = None, ): if engine_idx is None: engine_idx = 0 - - per_engine_loggers = self.per_engine_logger_dict[engine_idx] - for logger in per_engine_loggers: - logger.record(scheduler_stats, iteration_stats, engine_idx) - - self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx) + for logger in self.stat_loggers: + logger.record( + scheduler_stats, + iteration_stats, + mm_cache_stats=mm_cache_stats, + engine_idx=engine_idx, + ) def log(self): - for per_engine_loggers in self.per_engine_logger_dict.values(): - for logger in per_engine_loggers: - logger.log() + for logger in self.stat_loggers: + logger.log() def log_engine_initialized(self): - self.prometheus_logger.log_engine_initialized() - - for per_engine_loggers in self.per_engine_logger_dict.values(): - for logger in per_engine_loggers: - logger.log_engine_initialized() + for agg_logger in self.stat_loggers: + agg_logger.log_engine_initialized() diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py index 5823737968f9..1eacb785aa84 100644 --- a/vllm/v1/metrics/prometheus.py +++ b/vllm/v1/metrics/prometheus.py @@ -3,7 +3,6 @@ import os import tempfile -from typing import Optional from prometheus_client import REGISTRY, CollectorRegistry, multiprocess @@ -12,7 +11,7 @@ logger = init_logger(__name__) # Global temporary directory for prometheus multiprocessing 
-_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None +_prometheus_multiproc_dir: tempfile.TemporaryDirectory | None = None def setup_multiprocess_prometheus(): diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index a6fe2062f70c..b845852a0c0d 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from typing import Optional, Union from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.spec_decode.metrics import SpecDecodingProm @@ -63,9 +62,9 @@ class RayGaugeWrapper(RayPrometheusMetric): def __init__( self, name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, - multiprocess_mode: Optional[str] = "", + documentation: str | None = "", + labelnames: list[str] | None = None, + multiprocess_mode: str | None = "", ): # All Ray metrics are keyed by WorkerId, so multiprocess modes like # "mostrecent", "all", "sum" do not apply. This logic can be manually @@ -77,7 +76,7 @@ def __init__( name=name, description=documentation, tag_keys=labelnames_tuple ) - def set(self, value: Union[int, float]): + def set(self, value: int | float): return self.metric.set(value) def set_to_current_time(self): @@ -92,8 +91,8 @@ class RayCounterWrapper(RayPrometheusMetric): def __init__( self, name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, + documentation: str | None = "", + labelnames: list[str] | None = None, ): labelnames_tuple = tuple(labelnames) if labelnames else None name = self._get_sanitized_opentelemetry_name(name) @@ -101,7 +100,7 @@ def __init__( name=name, description=documentation, tag_keys=labelnames_tuple ) - def inc(self, value: Union[int, float] = 1.0): + def inc(self, value: int | float = 1.0): if value == 0: return return self.metric.inc(value) @@ -114,9 +113,9 @@ class RayHistogramWrapper(RayPrometheusMetric): def __init__( self, name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, - buckets: Optional[list[float]] = None, + documentation: str | None = "", + labelnames: list[str] | None = None, + buckets: list[float] | None = None, ): labelnames_tuple = tuple(labelnames) if labelnames else None name = self._get_sanitized_opentelemetry_name(name) @@ -128,7 +127,7 @@ def __init__( boundaries=boundaries, ) - def observe(self, value: Union[int, float]): + def observe(self, value: int | float): return self.metric.observe(value) diff --git a/vllm/v1/metrics/reader.py b/vllm/v1/metrics/reader.py index 5d50fa9461d0..48c88e5b61cb 100644 --- a/vllm/v1/metrics/reader.py +++ b/vllm/v1/metrics/reader.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional from prometheus_client import REGISTRY from prometheus_client import Metric as PromMetric @@ -144,7 +143,7 @@ def get_metrics_snapshot() -> list[Metric]: return collected -def _get_samples(metric: PromMetric, suffix: Optional[str] = None) -> list[Sample]: +def _get_samples(metric: PromMetric, suffix: str | None = None) -> list[Sample]: name = (metric.name + suffix) if suffix is not None else metric.name return [s for s in metric.samples if s.name == name] diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 5564718d5165..7868141d1b1d 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -2,8 
+2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time +from collections import deque from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from vllm.v1.spec_decode.metrics import SpecDecodingStats @@ -13,24 +14,140 @@ @dataclass -class PrefixCacheStats: - """Stores prefix cache hit statistics.""" +class BaseCacheStats: + """Stores cache hit statistics.""" - # Whether reset_prefix_cache was invoked. reset: bool = False - # The number of new requests in this update. + """Whether the cache was reset.""" + requests: int = 0 - # The number of queries in these requests. Note that "queries" here - # means the number of tokens that were queried from the cache. + """The number of requests in this update.""" + queries: int = 0 - # The number of hits in these requests. + """The number of queries in these requests.""" + hits: int = 0 - # The number of previously preempted requests in this update. + """The number of hits in these requests.""" + + +class CachingMetrics: + """Metrics for caching with a hit rate of the most recent N requests. + Args: + max_recent_requests: The number of the most recent requests to aggregate. + Defaults to 1000. + """ + + def __init__(self, max_recent_requests: int = 1000) -> None: + super().__init__() + + self.max_recent_requests = max_recent_requests + # The current aggregated values. + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + + # A deque of (requests, queries, hits) for the most recent requests. + self.query_queue = deque[tuple[int, int, int]]() + + def observe(self, stats: BaseCacheStats): + """Observe the cache stats for a set of requests. + + This function is called with information gathered when new requests + are being scheduled and are looking for computed blocks. + + When there are more than `max_recent_requests` requests, the oldest set + of requests is removed from the metrics. + + Args: + stats: The cache stats. + """ + # The cache was reset before the current update. + # Reset the metrics before aggregating the current stats. + if stats.reset: + self.reset() + + # Do not append empty stats, so that useful info is not evicted + # by the sliding window. + if stats.requests == 0: + return + + # Update the metrics. + self.query_queue.append((stats.requests, stats.queries, stats.hits)) + self.aggregated_requests += stats.requests + self.aggregated_query_total += stats.queries + self.aggregated_query_hit += stats.hits + + # Remove the oldest stats until the number of requests does not exceed + # the limit. + # NOTE: We preserve the latest added stats regardless.
+ while ( + len(self.query_queue) > 1 + and self.aggregated_requests > self.max_recent_requests + ): + old_requests, old_queries, old_hits = self.query_queue.popleft() + self.aggregated_requests -= old_requests + self.aggregated_query_total -= old_queries + self.aggregated_query_hit -= old_hits + + def reset(self): + """Reset the metrics.""" + self.aggregated_requests = 0 + self.aggregated_query_total = 0 + self.aggregated_query_hit = 0 + self.query_queue.clear() + + @property + def empty(self) -> bool: + """Return true if no requests have been observed.""" + return self.aggregated_requests == 0 + + @property + def hit_rate(self) -> float: + """Calculate the hit rate for the past N requests.""" + if self.aggregated_query_total == 0: + return 0.0 + return self.aggregated_query_hit / self.aggregated_query_total + + +@dataclass +class PrefixCacheStats(BaseCacheStats): + """ + Stores prefix cache hit statistics. + - `reset`: Whether `reset_prefix_cache` was invoked. + - `queries`: Refers to the number of tokens that were queried. + """ + preempted_requests: int = 0 - # The `queries` number for preempted requests. + """The number of previously preempted requests in this update.""" + preempted_queries: int = 0 - # The `hits` number for preempted requests. + """The `queries` number for preempted requests.""" + preempted_hits: int = 0 + """The `hits` number for preempted requests.""" + + def record(self, num_tokens: int, num_hits: int, preempted: bool) -> None: + """Aggregate request information into the stats.""" + if preempted: + # Previously preempted request + self.preempted_requests += 1 + self.preempted_queries += num_tokens + self.preempted_hits += num_hits + else: + # New request + self.requests += 1 + self.queries += num_tokens + self.hits += num_hits + + +@dataclass +class MultiModalCacheStats(BaseCacheStats): + """ + Stores multi-modal cache hit statistics. + - `reset`: Whether `reset_mm_cache` was invoked. + - `queries`: Refers to the number of multi-modal data items + that were queried. 
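To make the sliding-window bookkeeping of the new `CachingMetrics` class concrete, here is a minimal usage sketch. It assumes `CachingMetrics` and `PrefixCacheStats` are importable from `vllm.v1.metrics.stats` as added in this diff; the numbers are made up for illustration.

from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats

metrics = CachingMetrics(max_recent_requests=1000)

# Two scheduler updates: 3 requests querying 300 tokens with 120 hits,
# then 1 request querying 50 tokens with 50 hits.
metrics.observe(PrefixCacheStats(requests=3, queries=300, hits=120))
metrics.observe(PrefixCacheStats(requests=1, queries=50, hits=50))

# The hit rate is aggregated over the recent window: (120 + 50) / (300 + 50).
print(f"hit rate: {metrics.hit_rate:.2%}")  # ~48.57%

# An update with reset=True clears the window before its own stats are added.
metrics.observe(PrefixCacheStats(reset=True, requests=1, queries=10, hits=0))
assert metrics.hit_rate == 0.0

The same class is instantiated separately for the prefix cache, the KV-connector prefix cache, and the multi-modal cache in `LoggingStatLogger` above, which is why `observe` takes the generic `BaseCacheStats`.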
+ """ @dataclass @@ -47,9 +164,10 @@ class SchedulerStats: kv_cache_usage: float = 0.0 prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats) + connector_prefix_cache_stats: PrefixCacheStats | None = None - spec_decoding_stats: Optional[SpecDecodingStats] = None - kv_connector_stats: Optional[dict[str, Any]] = None + spec_decoding_stats: SpecDecodingStats | None = None + kv_connector_stats: dict[str, Any] | None = None num_corrupted_reqs: int = 0 @@ -87,7 +205,7 @@ class FinishedRequestStats: e2e_latency: float = 0.0 num_prompt_tokens: int = 0 num_generation_tokens: int = 0 - max_tokens_param: Optional[int] = None + max_tokens_param: int | None = None queued_time: float = 0.0 prefill_time: float = 0.0 inference_time: float = 0.0 @@ -126,7 +244,7 @@ def update_from_output( is_prefilling: bool, prompt_len: int, req_stats: RequestStateStats, - lora_stats: Optional[LoRAStats], + lora_stats: LoRAStats | None, ): num_new_generation_tokens = len(output.new_token_ids) @@ -161,7 +279,7 @@ def update_from_events( events: list["EngineCoreEvent"], is_prefilling: bool, req_stats: RequestStateStats, - lora_stats: Optional[LoRAStats], + lora_stats: LoRAStats | None, ): # Avoid circular dependency from vllm.v1.engine import EngineCoreEventType @@ -183,7 +301,7 @@ def update_from_finished_request( self, finish_reason: "FinishReason", num_prompt_tokens: int, - max_tokens_param: Optional[int], + max_tokens_param: int | None, req_stats: RequestStateStats, ): e2e_latency = self._time_since(req_stats.arrival_time) @@ -231,7 +349,7 @@ class LoRARequestStates: def __init__(self): self.lora_name_to_stats: dict[str, LoRAStats] = {} - def get_stats(self, req_state: "RequestState") -> Optional[LoRAStats]: + def get_stats(self, req_state: "RequestState") -> LoRAStats | None: if req_state.lora_name is None: return None if req_state.lora_name not in self.lora_name_to_stats: @@ -258,20 +376,20 @@ def abort_request(self, req_state: "RequestState"): # Break the pattern for this lifecycle methods so we can # call this from IterationStats.update_from_events() @staticmethod - def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str): + def scheduled_request(lora_stats: LoRAStats | None, request_id: str): if lora_stats is None: return lora_stats.waiting_requests.remove(request_id) lora_stats.running_requests.add(request_id) @staticmethod - def preempted_request(lora_stats: Optional[LoRAStats], request_id: str): + def preempted_request(lora_stats: LoRAStats | None, request_id: str): if lora_stats is None: return lora_stats.running_requests.remove(request_id) lora_stats.waiting_requests.add(request_id) - def update_iteration_stats(self, iteration_stats: Optional[IterationStats]): + def update_iteration_stats(self, iteration_stats: IterationStats | None): if iteration_stats is None: return for lora_name, stats in self.lora_name_to_stats.items(): diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index d647b207575c..10f97576b60a 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -3,43 +3,60 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, NamedTuple import torch if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +else: + KVConnectorStats = object class LogprobsLists(NamedTuple): - # [num_reqs, max_num_logprobs + 1] + # [num_reqs x num_generated_tokens, max_num_logprobs + 1] logprob_token_ids: list[list[int]] 
- # [num_reqs, max_num_logprobs + 1] + # [num_reqs x num_generated_tokens, max_num_logprobs + 1] logprobs: list[list[float]] - # [num_reqs] + # [num_reqs x num_generated_tokens] sampled_token_ranks: list[int] - - def slice(self, start: int, end: int): + # [num_reqs] + # Used for slicing the logprobs in cases like speculative + # decoding where the number of generated tokens may be + # different for each request. + cu_num_generated_tokens: list[int] | None = None + + def slice(self, start_req_idx: int, end_req_idx: int): + if self.cu_num_generated_tokens: + start = self.cu_num_generated_tokens[start_req_idx] + end = self.cu_num_generated_tokens[end_req_idx] + else: + start = start_req_idx + end = end_req_idx return LogprobsLists( self.logprob_token_ids[start:end], self.logprobs[start:end], self.sampled_token_ranks[start:end], + self.cu_num_generated_tokens[start_req_idx:end_req_idx] + if self.cu_num_generated_tokens + else None, ) class LogprobsTensors(NamedTuple): - # [num_reqs, max_num_logprobs + 1] + # [num_reqs x num_generated_tokens, max_num_logprobs + 1] logprob_token_ids: torch.Tensor - # [num_reqs, max_num_logprobs + 1] + # [num_reqs x num_generated_tokens, max_num_logprobs + 1] logprobs: torch.Tensor - # [num_reqs] + # [num_reqs x num_generated_tokens] selected_token_ranks: torch.Tensor - def tolists(self): + def tolists(self, cu_num_generated_tokens: list[int] | None = None): return LogprobsLists( self.logprob_token_ids.tolist(), self.logprobs.tolist(), self.selected_token_ranks.tolist(), + cu_num_generated_tokens, ) @staticmethod @@ -64,7 +81,7 @@ def empty_cpu( # [num_reqs, <dynamic>] # The shape of each element depends on the pooler used -PoolerOutput = Union[torch.Tensor, list[torch.Tensor]] +PoolerOutput = torch.Tensor | list[torch.Tensor] @dataclass @@ -74,18 +91,24 @@ class SamplerOutput: # All requests are padded to max_num_generated_tokens. # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding. sampled_token_ids: torch.Tensor - logprobs_tensors: Optional[LogprobsTensors] + logprobs_tensors: LogprobsTensors | None @dataclass class KVConnectorOutput: # [req_ids] - finished_sending: Optional[set[str]] = None - finished_recving: Optional[set[str]] = None - kv_connector_stats: Optional["KVConnectorStats"] = None + finished_sending: set[str] | None = None + finished_recving: set[str] | None = None + kv_connector_stats: KVConnectorStats | None = None # IDs of externally computed KV blocks that failed to load. - # Requests referencing these blocks should be rescheduled to recompute them. + # Requests referencing these blocks should be rescheduled to recompute them invalid_block_ids: set[int] = field(default_factory=set) + # Configuration describing how many finished sending/receiving + # notifications should be expected for each request. This allows + # handshake-based connectors like Nixl to update the KVOutputAggregator. + # It captures a static setup info and should almost always remain constant + # for a given connector after discovery. Default value entails no change. 
+ expected_finished_count: int = 0 def is_empty(self): return ( @@ -114,21 +137,21 @@ class ModelRunnerOutput: # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] # [num_reqs] - logprobs: Optional[LogprobsLists] + logprobs: LogprobsLists | None # req_id -> (token_ids, logprobs, ranks) # [prompt_len, num_prompt_logprobs] # [prompt_len, num_prompt_logprobs] # [prompt_len] - prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] + prompt_logprobs_dict: dict[str, LogprobsTensors | None] # [num_reqs, hidden_size] - pooler_output: list[Optional[torch.Tensor]] + pooler_output: list[torch.Tensor | None] - kv_connector_output: Optional[KVConnectorOutput] = None + kv_connector_output: KVConnectorOutput | None = None # req_id -> num_nans_in_logits - num_nans_in_logits: Optional[dict[str, int]] = None + num_nans_in_logits: dict[str, int] | None = None # ModelRunnerOutput wrapper for async scheduling. diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 36ae5b40a313..9883ab8fb996 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import torch from vllm.pooling_params import PoolingParams -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available pin_memory = is_pin_memory_available() @@ -37,9 +36,9 @@ class PoolingMetadata: """Tensors for pooling.""" prompt_lens: torch.Tensor # CPU Tensor - prompt_token_ids: Optional[torch.Tensor] + prompt_token_ids: torch.Tensor | None pooling_params: list[PoolingParams] - pooling_cursor: Optional[PoolingCursor] = None + pooling_cursor: PoolingCursor | None = None def __getitem__(self, indices: slice): return PoolingMetadata( diff --git a/vllm/v1/request.py b/vllm/v1/request.py index ac6e583099bc..864b0eb7fa41 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,9 +3,9 @@ import enum import time -from collections.abc import Mapping +from collections.abc import Callable, Mapping from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import torch @@ -31,20 +31,19 @@ class Request: def __init__( self, request_id: str, - prompt_token_ids: Optional[list[int]], - sampling_params: Optional[SamplingParams], - pooling_params: Optional[PoolingParams], - eos_token_id: Optional[int], + prompt_token_ids: list[int] | None, + sampling_params: SamplingParams | None, + pooling_params: PoolingParams | None, + eos_token_id: int | None, client_index: int = 0, - arrival_time: Optional[float] = None, - prompt_embeds: Optional[torch.Tensor] = None, - mm_features: Optional[list[MultiModalFeatureSpec]] = None, + arrival_time: float | None = None, + prompt_embeds: torch.Tensor | None = None, + mm_features: list[MultiModalFeatureSpec] | None = None, lora_request: Optional["LoRARequest"] = None, - structured_output_request: Optional["StructuredOutputRequest"] = None, - cache_salt: Optional[str] = None, + cache_salt: str | None = None, priority: int = 0, - trace_headers: Optional[Mapping[str, str]] = None, - block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] = None, + trace_headers: Mapping[str, str] | None = None, + block_hasher: Callable[["Request"], list["BlockHash"]] | None = None, ) -> None: self.request_id = request_id self.client_index = client_index @@ -54,16 +53,17 @@ def __init__( # 
Because of LoRA, the eos token id can be different for each request. self.eos_token_id = eos_token_id self.lora_request = lora_request - self.structured_output_request = structured_output_request + self.structured_output_request = StructuredOutputRequest.from_sampling_params( + sampling_params + ) self.arrival_time = arrival_time if arrival_time is not None else time.time() self.status = RequestStatus.WAITING - self.use_structured_output = False self.events: list[EngineCoreEvent] = [] - self.stop_reason: Union[int, str, None] = None + self.stop_reason: int | str | None = None # P/D: Connector-specific KV transfer parameters. - self.kv_transfer_params: Optional[dict[str, Any]] = None + self.kv_transfer_params: dict[str, Any] | None = None if pooling_params is not None: # Pooling models. @@ -72,9 +72,8 @@ def __init__( # Generative models. assert sampling_params.max_tokens is not None self.max_tokens = sampling_params.max_tokens - if sampling_params.structured_outputs is not None: + if self.structured_output_request is not None: self.status = RequestStatus.WAITING_FOR_FSM - self.use_structured_output = True if sampling_params.extra_args is not None: self.kv_transfer_params = sampling_params.extra_args.get( @@ -97,7 +96,7 @@ def __init__( self.num_output_placeholders = 0 # Used in async scheduling. self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 - self.cache_salt: Optional[str] = cache_salt + self.cache_salt: str | None = cache_salt # Multi-modal related self.mm_features = mm_features or [] @@ -123,7 +122,7 @@ def __init__( self.num_preemptions = 0 self.block_hashes: list[BlockHash] = [] - self.get_hash_new_full_blocks: Optional[Callable[[], list[BlockHash]]] = None + self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None if block_hasher is not None: self.get_hash_new_full_blocks = partial(block_hasher, self) self.block_hashes = self.get_hash_new_full_blocks() @@ -132,7 +131,7 @@ def __init__( def from_engine_core_request( cls, request: EngineCoreRequest, - block_hasher: Optional[Callable[["Request"], list["BlockHash"]]], + block_hasher: Callable[["Request"], list["BlockHash"]] | None, ) -> "Request": return cls( request_id=request.request_id, @@ -145,11 +144,6 @@ def from_engine_core_request( eos_token_id=request.eos_token_id, arrival_time=request.arrival_time, lora_request=request.lora_request, - structured_output_request=StructuredOutputRequest( - sampling_params=request.sampling_params - ) - if request.sampling_params - else None, cache_salt=request.cache_salt, priority=request.priority, trace_headers=request.trace_headers, @@ -158,7 +152,7 @@ def from_engine_core_request( def append_output_token_ids( self, - token_ids: Union[int, list[int]], + token_ids: int | list[int], ) -> None: if isinstance(token_ids, int): self._output_token_ids.append(token_ids) @@ -170,6 +164,10 @@ def append_output_token_ids( if self.get_hash_new_full_blocks is not None: self.block_hashes.extend(self.get_hash_new_full_blocks()) + @property + def use_structured_output(self) -> bool: + return self.structured_output_request is not None + @property def is_output_corrupted(self) -> bool: return self.num_nans_in_logits > 0 @@ -189,7 +187,7 @@ def num_output_tokens(self) -> int: def is_finished(self) -> bool: return RequestStatus.is_finished(self.status) - def get_finished_reason(self) -> Union[FinishReason, None]: + def get_finished_reason(self) -> FinishReason | None: return RequestStatus.get_finished_reason(self.status) def get_num_encoder_tokens(self, input_id: int) -> int: @@ 
-200,11 +198,11 @@ def get_num_encoder_tokens(self, input_id: int) -> int: def record_event( self, event_type: EngineCoreEventType, - timestamp: Optional[float] = None, + timestamp: float | None = None, ) -> None: self.events.append(EngineCoreEvent.new_event(event_type, timestamp)) - def take_events(self) -> Optional[list[EngineCoreEvent]]: + def take_events(self) -> list[EngineCoreEvent] | None: if not self.events: return None events, self.events = self.events, [] @@ -234,7 +232,7 @@ def is_finished(status: "RequestStatus") -> bool: return status > RequestStatus.PREEMPTED @staticmethod - def get_finished_reason(status: "RequestStatus") -> Union[FinishReason, None]: + def get_finished_reason(status: "RequestStatus") -> FinishReason | None: return _FINISHED_REASON_MAP.get(status) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index e9935f72c17f..566de5bcda77 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -6,7 +6,7 @@ from abc import abstractmethod from collections.abc import Sequence from functools import partial -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import torch @@ -55,12 +55,7 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: """Load all installed logit processor plugins""" - import sys - - if sys.version_info < (3, 10): - from importlib_metadata import entry_points - else: - from importlib.metadata import entry_points + from importlib.metadata import entry_points installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) if len(installed_logitsprocs_plugins) == 0: @@ -86,7 +81,7 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: def _load_logitsprocs_by_fqcns( - logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], + logits_processors: Sequence[str | type[LogitsProcessor]] | None, ) -> list[type[LogitsProcessor]]: """Load logit processor types, identifying them by fully-qualified class names (FQCNs). @@ -151,7 +146,7 @@ def _load_logitsprocs_by_fqcns( def _load_custom_logitsprocs( - logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], + logits_processors: Sequence[str | type[LogitsProcessor]] | None, ) -> list[type[LogitsProcessor]]: """Load all custom logits processors. @@ -181,7 +176,7 @@ def build_logitsprocs( device: torch.device, is_pin_memory: bool, is_pooling_model: bool, - custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (), + custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (), ) -> LogitsProcessors: if is_pooling_model: if custom_logitsprocs: @@ -254,7 +249,7 @@ def __init__( def new_req_logits_processor( self, params: SamplingParams, - ) -> Optional[RequestLogitsProcessor]: + ) -> RequestLogitsProcessor | None: """Consume request info; return a per-request logits processor. 
Return None if logits processor does not need to be applied to request @@ -272,9 +267,9 @@ def new_req_logits_processor( def _new_state( self, params: SamplingParams, - prompt_ids: Optional[list[int]], + prompt_ids: list[int] | None, output_ids: list[int], - ) -> Optional[partial[torch.Tensor]]: + ) -> partial[torch.Tensor] | None: """Return state representation for new request Returns None if logits processor is not applicable to request @@ -297,7 +292,7 @@ def _new_state( return partial(req_lp, *args) return None - def update_state(self, batch_update: Optional[BatchUpdate]): + def update_state(self, batch_update: BatchUpdate | None): process_dict_updates( self.req_info, batch_update, diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 3c3ddda7fb3e..4ee7dc2880c8 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import TYPE_CHECKING, Callable, Optional, TypeVar +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, TypeVar import torch @@ -49,7 +49,7 @@ def is_argmax_invariant(self) -> bool: def get_min_p_by_index(self, index: int) -> float: return float(self.min_p_cpu[index]) - def update_state(self, batch_update: Optional[BatchUpdate]): + def update_state(self, batch_update: BatchUpdate | None): if not batch_update: return @@ -131,7 +131,7 @@ def is_argmax_invariant(self) -> bool: outcome of argmax in greedy sampling.""" return False - def update_state(self, batch_update: Optional[BatchUpdate]): + def update_state(self, batch_update: BatchUpdate | None): needs_update = process_dict_updates( self.biases, batch_update, lambda params, _, __: params.logit_bias or None ) @@ -185,14 +185,14 @@ def is_argmax_invariant(self) -> bool: @staticmethod def add_request( - params: SamplingParams, _: Optional[list[int]], output_tok_ids: list[int] - ) -> Optional[tuple[int, Sequence[int], set[int]]]: + params: SamplingParams, _: list[int] | None, output_tok_ids: list[int] + ) -> tuple[int, Sequence[int], set[int]] | None: min_tokens = params.min_tokens if not min_tokens or len(output_tok_ids) >= min_tokens: return None return min_tokens, output_tok_ids, params.all_stop_token_ids - def update_state(self, batch_update: Optional[BatchUpdate]): + def update_state(self, batch_update: BatchUpdate | None): needs_update = process_dict_updates( self.min_toks, batch_update, self.add_request ) @@ -235,8 +235,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: def process_dict_updates( req_entries: dict[int, T], - batch_update: Optional[BatchUpdate], - new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], Optional[T]], + batch_update: BatchUpdate | None, + new_state: Callable[[SamplingParams, list[int] | None, list[int]], T | None], ) -> bool: """Utility function to update dict state for sparse LogitsProcessors.""" diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index 713bd21d3855..efa0f62ad6e1 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -26,7 +26,7 @@ class MoveDirectionality(Enum): # (index, params, prompt_tok_ids, output_tok_ids) tuples for new # requests added to the batch. 
-AddedRequest = tuple[int, SamplingParams, Optional[list[int]], list[int]] +AddedRequest = tuple[int, SamplingParams, list[int] | None, list[int]] # (index 1, index 2, directionality) tuples representing # one-way moves or two-way swaps of requests in batch diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py index a601f6641581..c15219da5cf7 100644 --- a/vllm/v1/sample/logits_processor/state.py +++ b/vllm/v1/sample/logits_processor/state.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterator from itertools import chain -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from vllm.v1.sample.logits_processor.interface import ( AddedRequest, @@ -43,9 +43,9 @@ class BatchUpdateBuilder: def __init__( self, - removed: Optional[list[RemovedRequest]] = None, - added: Optional[list[AddedRequest]] = None, - moved: Optional[list[MovedRequest]] = None, + removed: list[RemovedRequest] | None = None, + added: list[AddedRequest] | None = None, + moved: list[MovedRequest] | None = None, ) -> None: self._removed = removed or [] self.added = added or [] @@ -92,14 +92,14 @@ def removed_append(self, index: int) -> None: def has_removed(self) -> bool: return bool(self._removed) - def peek_removed(self) -> Optional[int]: + def peek_removed(self) -> int | None: """Return lowest removed request index""" if self.has_removed(): self._ensure_removed_sorted() return self._removed[-1] return None - def pop_removed(self) -> Optional[int]: + def pop_removed(self) -> int | None: """Pop lowest removed request index""" if self.has_removed(): self._ensure_removed_sorted() @@ -116,7 +116,7 @@ def reset(self) -> bool: self.batch_changed = False return batch_changed - def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: + def get_and_reset(self, batch_size: int) -> BatchUpdate | None: """Generate a logitsprocs batch update data structure and reset internal batch update builder state. 
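As a rough illustration of the `process_dict_updates` pattern that the builtin processors above rely on, the sketch below shows a hypothetical custom logits processor. The class name, the no-op constructor, and the `banned_token_id` entry read from `SamplingParams.extra_args` are invented for this example; the import paths follow the module layout shown in this diff (`interface.py` and `builtin.py`).

import torch

from vllm.v1.sample.logits_processor.builtin import process_dict_updates
from vllm.v1.sample.logits_processor.interface import BatchUpdate, LogitsProcessor


class BanTokenLogitsProcessor(LogitsProcessor):
    """Hypothetical processor that masks one banned token id per request."""

    def __init__(self, *args, **kwargs):
        # Whatever construction arguments vLLM passes are ignored in this sketch.
        self.banned: dict[int, int] = {}  # req_index -> banned token id

    def is_argmax_invariant(self) -> bool:
        # Masking a token can change the argmax outcome, so report False.
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        # Keep self.banned in sync with requests added/removed/moved in the batch.
        process_dict_updates(
            self.banned,
            batch_update,
            # (params, prompt_tok_ids, output_tok_ids) -> per-request state or None
            lambda params, _, __: (params.extra_args or {}).get("banned_token_id"),
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for req_index, token_id in self.banned.items():
            logits[req_index, token_id] = float("-inf")
        return logits

A processor like this would be handed to vLLM by fully-qualified class name through the custom logits-processor loading path (`_load_logitsprocs_by_fqcns`) shown earlier in this diff.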
@@ -148,9 +148,7 @@ def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: class LogitsProcessors: """Encapsulates initialized logitsproc objects.""" - def __init__( - self, logitsprocs: Optional[Iterator["LogitsProcessor"]] = None - ) -> None: + def __init__(self, logitsprocs: Iterator["LogitsProcessor"] | None = None) -> None: self.argmax_invariant: list[LogitsProcessor] = [] self.non_argmax_invariant: list[LogitsProcessor] = [] if logitsprocs: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index e252ace97d27..b1101b1b2318 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Optional import torch @@ -11,20 +10,20 @@ @dataclass class SamplingMetadata: - temperature: Optional[torch.Tensor] + temperature: torch.Tensor | None all_greedy: bool all_random: bool - top_p: Optional[torch.Tensor] - top_k: Optional[torch.Tensor] + top_p: torch.Tensor | None + top_k: torch.Tensor | None generators: dict[int, torch.Generator] # None means no logprobs, 0 means sampled token logprobs only - max_num_logprobs: Optional[int] + max_num_logprobs: int | None no_penalties: bool - prompt_token_ids: Optional[torch.Tensor] + prompt_token_ids: torch.Tensor | None frequency_penalties: torch.Tensor presence_penalties: torch.Tensor repetition_penalties: torch.Tensor @@ -33,7 +32,7 @@ class SamplingMetadata: # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, # vocab size). - allowed_token_ids_mask: Optional[torch.Tensor] + allowed_token_ids_mask: torch.Tensor | None # req_index -> bad_words_token_ids bad_words_token_ids: dict[int, list[list[int]]] @@ -42,4 +41,4 @@ class SamplingMetadata: logitsprocs: LogitsProcessors # Speculative token ids - spec_token_ids: Optional[list[list[int]]] = None + spec_token_ids: list[list[int]] | None = None diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index e49b8db47800..898b90d41aba 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -4,7 +4,8 @@ import torch from vllm.model_executor.layers.utils import apply_penalties -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad def apply_all_penalties( diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index dbcdad07e4de..7a4b224822bd 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -1,26 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch import torch.nn as nn from packaging import version from vllm import envs -from vllm.config import LogprobsMode +from vllm.config.model import LogprobsMode from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.platforms import CpuArchEnum, current_platform logger = init_logger(__name__) -try: - import flashinfer.sampling - - is_flashinfer_available = True -except ImportError: - is_flashinfer_available = False - class TopKTopPSampler(nn.Module): """ @@ -39,42 +31,30 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: logprobs_mode not in ("processed_logits", "processed_logprobs") and current_platform.is_cuda() ): - if 
is_flashinfer_available: - flashinfer_version = flashinfer.__version__ - if version.parse(flashinfer_version) < version.parse("0.2.3"): - logger.warning_once( - "FlashInfer version >= 0.2.3 required. " - "Falling back to default sampling implementation." - ) - self.forward = self.forward_native - elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False: - # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for - # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by - # default it is unused). For backward compatibility, we set - # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and - # interpret it differently in V0 and V1 samplers: In V0, - # None means False, while in V1, None means True. This is - # why we use the condition - # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. - logger.info_once("Using FlashInfer for top-p & top-k sampling.") - self.forward = self.forward_cuda - else: - logger.warning_once( - "FlashInfer is available, but it is not enabled. " - "Falling back to the PyTorch-native implementation of " - "top-p & top-k sampling. For the best performance, " - "please set VLLM_USE_FLASHINFER_SAMPLER=1." - ) - self.forward = self.forward_native + if envs.VLLM_USE_FLASHINFER_SAMPLER: + # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1. + logger.info_once( + "Using FlashInfer for top-p & top-k sampling.", + scope="global", + ) + self.forward = self.forward_cuda else: - logger.warning_once( - "FlashInfer is not available. Falling back to the PyTorch-" - "native implementation of top-p & top-k sampling. For the " - "best performance, please install FlashInfer." + logger.debug_once( + "FlashInfer top-p/top-k sampling is available but disabled " + "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in " + "after verifying accuracy for your workloads." ) self.forward = self.forward_native + elif current_platform.is_cpu(): - self.forward = self.forward_cpu + arch = current_platform.get_cpu_architecture() + # Fall back to native implementation for POWERPC and RISCV. + # On PowerPC argmax produces incorrect output with torch.compile. + # PR: https://github.com/vllm-project/vllm/pull/26987 + if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC): + self.forward = self.forward_native + else: + self.forward = self.forward_cpu else: self.forward = self.forward_native @@ -84,9 +64,9 @@ def forward_native( self, logits: torch.Tensor, generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + k: torch.Tensor | None, + p: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """ PyTorch-native implementation of top-k and top-p sampling. @@ -105,9 +85,9 @@ def forward_cuda( self, logits: torch.Tensor, generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + k: torch.Tensor | None, + p: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """More optimized implementation for top-k and top-p sampling.""" # We prefer `random_sample` over `flashinfer_sample` when sorting is # not needed. 
This is because `random_sample` does not require @@ -132,9 +112,9 @@ def forward_cpu( self, logits: torch.Tensor, generators: dict[int, torch.Generator], - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + k: torch.Tensor | None, + p: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """ PyTorch-native implementation of top-k and top-p sampling for CPU. @@ -170,8 +150,8 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor: def apply_top_k_top_p( logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], + k: torch.Tensor | None, + p: torch.Tensor | None, ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. @@ -262,8 +242,8 @@ def random_sample( def flashinfer_sample( logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], + k: torch.Tensor | None, + p: torch.Tensor | None, generators: dict[int, torch.Generator], ) -> torch.Tensor: """Sample from the logits using FlashInfer. @@ -280,6 +260,13 @@ def flashinfer_sample( does not. Call this function at the end of the forward pass to minimize the synchronization overhead. """ + import flashinfer + + if version.parse(flashinfer.__version__) < version.parse("0.2.3"): + raise ImportError( + "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. " + ) + assert not (k is None and p is None) if k is None: # Top-p only. diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 76555a866685..926305d25f56 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,22 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional + +from dataclasses import replace import torch import torch.nn as nn from vllm.logger import init_logger from vllm.triton_utils import tl, triton +from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata logger = init_logger(__name__) PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 -GREEDY_TEMPERATURE: tl.constexpr = -1 +GREEDY_TEMPERATURE: tl.constexpr = 0 # Maximum number of speculative draft tokens allowed per request in a single # step. This value is chosen to be large enough to handle typical use cases. 
MAX_SPEC_LEN = 128 @@ -45,17 +48,22 @@ class RejectionSampler(nn.Module): output tokens = accepted tokens + recovered tokens + bonus tokens """ + def __init__(self, sampler: Sampler): + super().__init__() + self.sampler = sampler + logprobs_mode = self.sampler.logprobs_mode + self.is_processed_logprobs_mode = logprobs_mode.startswith("processed") + self.is_logits_logprobs_mode = logprobs_mode.endswith("logits") + def forward( self, metadata: SpecDecodeMetadata, # [num_tokens, vocab_size] - draft_probs: Optional[torch.Tensor], - # [num_tokens, vocab_size] - target_logits: torch.Tensor, - # [batch_size, 1] - bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor | None, + # [num_tokens + batch_size, vocab_size] + logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> SamplerOutput: """ Args: metadata: @@ -64,43 +72,65 @@ def forward( Probability distribution for the draft tokens. Shape is [num_tokens, vocab_size]. Can be None if probabilities are not provided, which is the case for ngram spec decode. - target_logits (torch.Tensor): + logits (torch.Tensor): Target model's logits probability distribution. - Shape is [num_tokens, vocab_size]. Here, probabilities from - different requests are flattened into a single tensor because - this is the shape of the output logits. - NOTE: `target_logits` can be updated in place to save memory. - bonus_token_ids (torch.Tensor): - A tensor containing bonus tokens. Shape is [batch_size, 1]. - Bonus tokens are added to the end of the sequence if all - proposed tokens are accepted. We generate the bonus tokens - outside of the rejection sampler with the default sampling - strategy. It allows for more flexibility in the sampling - process such as top_p, top_k sampling. + Shape is [num_tokens + batch_size, vocab_size]. Here, + probabilities from different requests are flattened into a + single tensor because this is the shape of the output logits. + NOTE: `logits` can be updated in place to save memory. sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata): Additional metadata needed for sampling, such as temperature, top-k/top-p parameters, or other relevant information. Returns: - output_token_ids (torch.Tensor): - A tensor containing the final output token IDs. + SamplerOutput: + Contains the final output token IDs and their logprobs if + requested. """ assert metadata.max_spec_len <= MAX_SPEC_LEN - # Use float32 for the target_logits. - target_logits = target_logits.to(torch.float32) + bonus_logits_indices = metadata.bonus_logits_indices + target_logits_indices = metadata.target_logits_indices + + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. + assert logits is not None + bonus_logits = logits[bonus_logits_indices] + bonus_sampler_output = self.sampler( + logits=bonus_logits, + sampling_metadata=replace( + sampling_metadata, + max_num_logprobs=-1, + ), + predict_bonus_token=True, + # Override the logprobs mode to return logits because they are + # needed later to compute the accepted token logprobs. + logprobs_mode_override="processed_logits" + if self.is_processed_logprobs_mode + else "raw_logits", + ) + bonus_token_ids = bonus_sampler_output.sampled_token_ids + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. 
Therefore, + # it is safe to update `target_logits` in place. + raw_target_logits = logits[target_logits_indices] + # Use float32 for the target_logits. + raw_target_logits = raw_target_logits.to(torch.float32) target_logits = self.apply_logits_processors( - target_logits, sampling_metadata, metadata + raw_target_logits, sampling_metadata, metadata ) - # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the - # `compute_probs` function. - target_probs = compute_probs( + # `apply_sampling_constraints` function. + target_logits = apply_sampling_constraints( target_logits, metadata.cu_num_draft_tokens, sampling_metadata, ) + # Compute probability distribution from target logits. + target_probs = target_logits.softmax(dim=-1, dtype=torch.float32) output_token_ids = rejection_sample( metadata.draft_token_ids, @@ -112,7 +142,63 @@ def forward( bonus_token_ids, sampling_metadata, ) - return output_token_ids + + logprobs_tensors = None + if sampling_metadata.max_num_logprobs: + logprobs_tensors = self._get_logprobs_tensors( + sampling_metadata.max_num_logprobs, + metadata, + logits, + target_logits if self.is_processed_logprobs_mode else raw_target_logits, + bonus_sampler_output.logprobs_tensors.logprobs, + output_token_ids, + ) + + return SamplerOutput( + sampled_token_ids=output_token_ids, + logprobs_tensors=logprobs_tensors, + ) + + def _get_logprobs_tensors( + self, + max_num_logprobs: int, + metadata: SpecDecodeMetadata, + logits: torch.Tensor, + target_logits: torch.Tensor, + bonus_logits: torch.Tensor, + sampled_token_ids: torch.Tensor, + ) -> LogprobsTensors: + cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens) + cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1] + + # Collect target and bonus logits. + bonus_logits_indices = metadata.bonus_logits_indices + target_logits_indices = metadata.target_logits_indices + final_logits = torch.zeros_like(logits, dtype=torch.float32) + final_logits[target_logits_indices] = target_logits.to(torch.float32) + final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32) + + # Compute accepted token indices. + accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID + num_accepted_tokens = accepted_mask.sum(dim=-1) + accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1] + accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave( + num_accepted_tokens + ) + + # Compute logprobs for accepted tokens. + accepted_logits = final_logits[accepted_logit_indices] + accepted_logprobs = ( + accepted_logits + if self.is_logits_logprobs_mode + else self.sampler.compute_logprobs(accepted_logits) + ) + accepted_tokens = sampled_token_ids[accepted_mask] + return self.sampler.gather_logprobs( + accepted_logprobs, + max_num_logprobs, + accepted_tokens.to(torch.int64), + ) @staticmethod def parse_output( @@ -120,14 +206,12 @@ def parse_output( vocab_size: int, ) -> list[list[int]]: """Parse the output of the rejection sampler. - Args: output_token_ids: The sampled token IDs in shape [batch_size, max_spec_len + 1]. The rejected tokens are replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler and will be filtered out in this function. vocab_size: The size of the vocabulary. - Returns: A list of lists of token IDs. 
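The contract described for `parse_output` above amounts to dropping placeholder and out-of-vocabulary ids row by row. A rough pure-Python rendering, with made-up token ids and not the actual vectorized implementation, might look like:

import torch

PLACEHOLDER_TOKEN_ID = -1
vocab_size = 32_000

# [batch_size, max_spec_len + 1]; rejected positions hold the placeholder id.
output_token_ids = torch.tensor(
    [
        [11, 12, -1, -1],  # request 0: two tokens kept, the rest rejected
        [21, 22, 23, 24],  # request 1: all drafts plus the bonus token accepted
    ]
)

parsed = [
    [t for t in row.tolist() if t != PLACEHOLDER_TOKEN_ID and t < vocab_size]
    for row in output_token_ids
]
print(parsed)  # [[11, 12], [21, 22, 23, 24]]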
""" @@ -147,22 +231,20 @@ def apply_logits_processors( sampling_metadata: SamplingMetadata, metadata: SpecDecodeMetadata, ) -> torch.Tensor: + has_penalties = not sampling_metadata.no_penalties any_penalties_or_bad_words = ( - sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties + sampling_metadata.bad_words_token_ids or has_penalties ) output_token_ids = sampling_metadata.output_token_ids if any_penalties_or_bad_words: output_token_ids = self._combine_outputs_with_spec_tokens( - sampling_metadata.output_token_ids, + output_token_ids, sampling_metadata.spec_token_ids, ) # Calculate indices of target logits. - if ( - sampling_metadata.allowed_token_ids_mask is not None - or not sampling_metadata.no_penalties - ): + if sampling_metadata.allowed_token_ids_mask is not None or has_penalties: num_requests = len(sampling_metadata.output_token_ids) num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu") original_indices = torch.arange(num_requests, device="cpu") @@ -180,18 +262,15 @@ def apply_logits_processors( logits.masked_fill_(token_mask, float("-inf")) # Apply bad words exclusion. - if sampling_metadata.bad_words_token_ids: + if bad_words_token_ids := sampling_metadata.bad_words_token_ids: apply_bad_words_with_drafts( - logits, - sampling_metadata.bad_words_token_ids, - output_token_ids, - metadata.num_draft_tokens, + logits, bad_words_token_ids, output_token_ids, metadata.num_draft_tokens ) return logits + @staticmethod def apply_penalties( - self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, metadata: SpecDecodeMetadata, @@ -218,10 +297,10 @@ def apply_penalties( ) return logits + @staticmethod def _combine_outputs_with_spec_tokens( - self, output_token_ids: list[list[int]], - spec_token_ids: Optional[list[list[int]]] = None, + spec_token_ids: list[list[int]] | None = None, ) -> list[list[int]]: if spec_token_ids is None: return output_token_ids @@ -245,7 +324,7 @@ def rejection_sample( # [batch_size] cu_num_draft_tokens: torch.Tensor, # [num_tokens, vocab_size] - draft_probs: Optional[torch.Tensor], + draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] target_probs: torch.Tensor, # [batch_size, 1] @@ -334,27 +413,26 @@ def rejection_sample( return output_token_ids -def compute_probs( +def apply_sampling_constraints( logits: torch.Tensor, # [num_tokens, vocab_size] cu_num_draft_tokens: torch.Tensor, # [batch_size] sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - """Compute probability distribution from logits based on sampling metadata. + """Process logits based on sampling metadata. - This function applies temperature scaling to the logits and converts - them to probabilities using softmax. For greedy decoding, it returns + This function applies temperature scaling to the logits, + as well as top-k and top-p. For greedy decoding, it returns the original logits. Args: - logits: Input logits tensor to be converted to probabilities. + logits: Input logits tensor to be processed. cu_num_draft_tokens: Cumulative number of draft tokens. sampling_metadata: Metadata containing sampling parameters such as temperature and whether greedy sampling is used. Returns: - torch.Tensor: Probability distribution (softmax of scaled logits) - if non-greedy sampling is used, otherwise returns the - original logits. + torch.Tensor: Processed logits if non-greedy sampling is used, + otherwise returns the original logits. 
""" assert logits.ndim == 2 assert cu_num_draft_tokens.ndim == 1 @@ -390,9 +468,7 @@ def compute_probs( # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, # which is slow for large vocab sizes. This may cause performance issues. - logits = apply_top_k_top_p(logits, top_k, top_p) - output_prob = logits.softmax(dim=-1, dtype=torch.float32) - return output_prob + return apply_top_k_top_p(logits, top_k, top_p) def expand_batch_to_tokens( @@ -498,7 +574,7 @@ def sample_recovered_tokens( # [num_tokens] draft_token_ids: torch.Tensor, # [num_tokens, vocab_size] - draft_probs: Optional[torch.Tensor], + draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 101d2ebed4b7..39c63fe31ad2 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -2,13 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that samples the next tokens from the model's outputs.""" -from typing import Optional - import torch import torch.nn as nn -from vllm.config import LogprobsMode -from vllm.utils import is_pin_memory_available +from vllm.config.model import LogprobsMode +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.bad_words import apply_bad_words @@ -71,16 +69,18 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, predict_bonus_token: bool = False, + logprobs_mode_override: LogprobsMode | None = None, ) -> SamplerOutput: + logprobs_mode = logprobs_mode_override or self.logprobs_mode # NOTE(woosuk): Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. # This is different from the V0 sampler, which uses the logits that # is used for sampling (after penalties and temperature scaling). num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - if self.logprobs_mode == "raw_logprobs": + if logprobs_mode == "raw_logprobs": raw_logprobs = self.compute_logprobs(logits) - elif self.logprobs_mode == "raw_logits": + elif logprobs_mode == "raw_logits": raw_logprobs = logits.clone() # Use float32 for the logits. @@ -99,13 +99,18 @@ def forward( # return int32 (while PyTorch argmax and topk return int64). sampled = sampled.long() - # Gather the logprobs of the topk and sampled token (if requested). - # Get logprobs and rank tensors (if requested) - logprobs_tensors = ( - None - if num_logprobs is None - else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) - ) + if num_logprobs is None: + logprobs_tensors = None + elif num_logprobs == -1: + # Return the full unsorted and unranked logprobs. + logprobs_tensors = LogprobsTensors( + torch.empty(0), raw_logprobs, torch.empty(0) + ) + else: + # Gather the logprobs and ranks of the topk and sampled token. + logprobs_tensors = self.gather_logprobs( + raw_logprobs, num_logprobs, token_ids=sampled + ) # Use int32 to reduce the tensor size. 
sampled = sampled.to(torch.int32) @@ -120,8 +125,8 @@ def forward( ) return sampler_output + @staticmethod def apply_temperature( - self, logits: torch.Tensor, temp: torch.Tensor, all_random: bool, @@ -132,20 +137,23 @@ def apply_temperature( temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) return logits.div_(temp.unsqueeze(dim=1)) - def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: + @staticmethod + def greedy_sample(logits: torch.Tensor) -> torch.Tensor: return logits.argmax(dim=-1).view(-1) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + logprobs_mode_override: LogprobsMode | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: """Sample logits based on sampling metadata. The various logits processing functions called in this method may update the logits tensor in-place. """ + logprobs_mode = logprobs_mode_override or self.logprobs_mode assert not (sampling_metadata.all_greedy and sampling_metadata.all_random) if sampling_metadata.all_random: greedy_sampled = None @@ -154,9 +162,9 @@ def sample( if sampling_metadata.all_greedy: processed_logprobs = None if sampling_metadata.max_num_logprobs is not None: - if self.logprobs_mode == "processed_logits": + if logprobs_mode == "processed_logits": processed_logprobs = logits - elif self.logprobs_mode == "processed_logprobs": + elif logprobs_mode == "processed_logprobs": processed_logprobs = self.compute_logprobs(logits) return greedy_sampled, processed_logprobs @@ -191,11 +199,12 @@ def sample( ) return sampled, processed_logprobs - def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: + @staticmethod + def compute_logprobs(logits: torch.Tensor) -> torch.Tensor: return logits.log_softmax(dim=-1, dtype=torch.float32) + @staticmethod def gather_logprobs( - self, logprobs: torch.Tensor, num_logprobs: int, token_ids: torch.Tensor, @@ -238,10 +247,10 @@ def gather_logprobs( return LogprobsTensors(indices, logprobs, token_ranks) + @staticmethod def _combine_outputs_with_spec_tokens( - self, output_token_ids: list[list[int]], - spec_token_ids: Optional[list[list[int]]] = None, + spec_token_ids: list[list[int]] | None = None, ) -> list[list[int]]: if spec_token_ids is None: return output_token_ids @@ -257,8 +266,9 @@ def apply_logits_processors( sampling_metadata: SamplingMetadata, predict_bonus_token: bool, ) -> torch.Tensor: + bad_words_token_ids = sampling_metadata.bad_words_token_ids any_penalties_or_bad_words = ( - sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties + bool(bad_words_token_ids) or not sampling_metadata.no_penalties ) output_token_ids = sampling_metadata.output_token_ids @@ -266,7 +276,7 @@ def apply_logits_processors( # Combine base outputs with spec tokens when speculative decoding # is enabled. output_token_ids = self._combine_outputs_with_spec_tokens( - sampling_metadata.output_token_ids, + output_token_ids, sampling_metadata.spec_token_ids, ) @@ -275,14 +285,8 @@ def apply_logits_processors( logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf")) # Apply bad words exclusion. - if sampling_metadata.bad_words_token_ids: - apply_bad_words( - logits, - sampling_metadata.bad_words_token_ids, - output_token_ids - if output_token_ids is not None - else sampling_metadata.output_token_ids, - ) + if bad_words_token_ids: + apply_bad_words(logits, bad_words_token_ids, output_token_ids) # Apply logits processors which can impact greedy sampling. 
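# --- Illustrative sketch (editor's example, not part of the patch) ---
# The override convention used above: a per-call logprobs_mode_override wins
# over the configured mode, and the suffix decides whether the values stay as
# logits or go through log_softmax. The mode strings match the diff; the
# helper function itself is invented.
import torch

def pick_logprobs(logits: torch.Tensor, configured_mode: str,
                  override: str | None = None) -> torch.Tensor:
    mode = override or configured_mode
    return logits.clone() if mode.endswith("logits") else logits.log_softmax(dim=-1)

x = torch.randn(2, 5)
assert torch.equal(pick_logprobs(x, "raw_logprobs", override="raw_logits"), x)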
for processor in sampling_metadata.logitsprocs.non_argmax_invariant: @@ -292,22 +296,21 @@ def apply_logits_processors( logits = self.apply_penalties(logits, sampling_metadata, output_token_ids) return logits + @staticmethod def apply_penalties( - self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, - output_token_ids: Optional[list[list[int]]] = None, + output_token_ids: list[list[int]], ) -> torch.Tensor: - if not sampling_metadata.no_penalties: - assert sampling_metadata.prompt_token_ids is not None - logits = apply_all_penalties( - logits, - sampling_metadata.prompt_token_ids, - sampling_metadata.presence_penalties, - sampling_metadata.frequency_penalties, - sampling_metadata.repetition_penalties, - output_token_ids - if output_token_ids is not None - else sampling_metadata.output_token_ids, - ) - return logits + if sampling_metadata.no_penalties: + return logits + + assert sampling_metadata.prompt_token_ids is not None + return apply_all_penalties( + logits, + sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + output_token_ids, + ) diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index b58a94d0bf7d..0c1a22e84ece 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from typing import Optional import torch @@ -31,6 +30,7 @@ class TPUSupportedSamplingMetadata: top_p: torch.Tensor = None all_greedy: bool = True + all_random: bool = False # Whether logprobs are to be gathered in this batch of request. To balance # out compile time and runtime, a fixed `max_number_logprobs` value is used @@ -48,7 +48,7 @@ class TPUSupportedSamplingMetadata: min_tokens = None # impl is not vectorized - logit_bias: list[Optional[dict[int, float]]] = field(default_factory=lambda: list()) + logit_bias: list[dict[int, float] | None] = field(default_factory=lambda: list()) allowed_token_ids_mask = None bad_words_token_ids = None @@ -111,6 +111,7 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor: xla_device ), all_greedy=input_batch.all_greedy, + all_random=input_batch.all_random, # TODO enable more and avoid returning None values top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device), top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device), diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index ccef283a8182..8f0463c76ce1 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampler layer implementing TPU supported operations.""" -from typing import Optional - import torch import torch.nn as nn @@ -42,7 +40,11 @@ def apply_temperature( self, logits: torch.Tensor, temp: torch.Tensor, + all_random: bool = False, ) -> torch.Tensor: + # Avoid division by zero for greedy sampling (temperature ~ 0.0). + if not all_random: + temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp) return logits.div_(temp.unsqueeze(dim=1)) def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: @@ -58,7 +60,9 @@ def sample( assert sampling_metadata.temperature is not None # Apply temperature. 
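# --- Illustrative sketch (editor's example, not part of the patch) ---
# The greedy-temperature guard in standalone form: when at least one request
# is greedy (temperature ~ 0), its temperature is rewritten to 1.0 before the
# division so the logits are never divided by zero. The epsilon and the batch
# values are invented.
import torch

_EPS = 1e-5

def scale_by_temperature(logits: torch.Tensor, temp: torch.Tensor,
                         all_random: bool) -> torch.Tensor:
    if not all_random:
        temp = torch.where(temp < _EPS, torch.ones_like(temp), temp)
    return logits / temp.unsqueeze(dim=1)

logits = torch.randn(3, 10)
temp = torch.tensor([0.0, 0.7, 1.0])    # request 0 is greedy
assert torch.isfinite(scale_by_temperature(logits, temp, all_random=False)).all()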
- logits = self.apply_temperature(logits, sampling_metadata.temperature) + logits = self.apply_temperature( + logits, sampling_metadata.temperature, sampling_metadata.all_random + ) # Apply min_p. if sampling_metadata.min_p is not None: @@ -166,8 +170,8 @@ def random_sample( def apply_top_k_top_p( logits: torch.Tensor, - k: Optional[torch.Tensor], - p: Optional[torch.Tensor], + k: torch.Tensor | None, + p: torch.Tensor | None, ) -> torch.Tensor: """ Apply top-k and top-p optimized for TPU. diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 747d08dcd367..39147a67d6cf 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -4,10 +4,10 @@ import dataclasses import importlib import pickle -from collections.abc import Sequence +from collections.abc import Callable, Sequence from inspect import isclass from types import FunctionType -from typing import Any, Callable, Optional, Union +from typing import Any, TypeAlias import cloudpickle import msgspec @@ -31,6 +31,7 @@ NestedTensors, ) from vllm.v1.engine import UtilityResult +from vllm.v1.utils import tensor_data logger = init_logger(__name__) @@ -47,7 +48,7 @@ MultiModalBatchedField: "batched", } -bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] +bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame def _log_insecure_serialization_warning(): @@ -57,7 +58,7 @@ def _log_insecure_serialization_warning(): ) -def _typestr(val: Any) -> Optional[tuple[str, str]]: +def _typestr(val: Any) -> tuple[str, str] | None: if val is None: return None t = type(val) @@ -111,14 +112,14 @@ class MsgpackEncoder: via dedicated messages. Note that this is a per-tensor limit. """ - def __init__(self, size_threshold: Optional[int] = None): + def __init__(self, size_threshold: int | None = None): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) # This is used as a local stash of buffers that we can then access from # our custom `msgspec` hook, `enc_hook`. We don't have a way to # pass custom data to the hook otherwise. - self.aux_buffers: Optional[list[bytestr]] = None + self.aux_buffers: list[bytestr] | None = None self.size_threshold = size_threshold if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -195,7 +196,7 @@ def enc_hook(self, obj: Any) -> Any: def _encode_ndarray( self, obj: np.ndarray - ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: + ) -> tuple[str, tuple[int, ...], int | memoryview]: assert self.aux_buffers is not None # If the array is non-contiguous, we need to copy it first arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() @@ -215,17 +216,17 @@ def _encode_ndarray( def _encode_tensor( self, obj: torch.Tensor - ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: + ) -> tuple[str, tuple[int, ...], int | memoryview]: assert self.aux_buffers is not None # view the tensor as a contiguous 1D array of bytes - arr = obj.flatten().contiguous().view(torch.uint8).numpy() + arr_data = tensor_data(obj) if obj.nbytes < self.size_threshold: # Smaller tensors are encoded inline, just like ndarrays. - data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) + data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data) else: # Otherwise encode index of backing buffer to avoid copy. 
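# --- Illustrative sketch (editor's example, not part of the patch) ---
# The inline-vs-out-of-band split used by the tensor encoder: payloads below
# the size threshold are copied inline, larger ones are replaced by the index
# of a zero-copy auxiliary buffer that is shipped separately. The threshold
# and the container here are invented, not vLLM's defaults.
SIZE_THRESHOLD = 256

def encode_payload(payload: memoryview, aux_buffers: list) -> bytes | int:
    if len(payload) < SIZE_THRESHOLD:
        return bytes(payload)        # small: inline copy
    index = len(aux_buffers)         # large: reference into aux_buffers
    aux_buffers.append(payload)
    return index

aux: list = []
assert isinstance(encode_payload(memoryview(b"x" * 10), aux), bytes)
assert encode_payload(memoryview(b"y" * 1024), aux) == 0 and len(aux) == 1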
data = len(self.aux_buffers) - self.aux_buffers.append(arr.data) + self.aux_buffers.append(arr_data) dtype = str(obj.dtype).removeprefix("torch.") return dtype, obj.shape, data @@ -280,7 +281,7 @@ class MsgpackDecoder: not thread-safe when encoding tensors / numpy arrays. """ - def __init__(self, t: Optional[Any] = None): + def __init__(self, t: Any | None = None): args = () if t is None else (t,) self.decoder = msgpack.Decoder( *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook @@ -289,10 +290,8 @@ def __init__(self, t: Optional[Any] = None): if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() - def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: - if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): - # TODO - This check can become `isinstance(bufs, bytestr)` - # as of Python 3.10. + def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any: + if isinstance(bufs, bytestr): # type: ignore return self.decoder.decode(bufs) self.aux_buffers = bufs diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1e1161727be1..35c2e73e8ee2 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,24 +3,28 @@ import ast from dataclasses import replace from importlib.util import find_spec -from typing import Optional import numpy as np import torch import torch.nn as nn -from vllm.attention.layer import Attention -from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config +from vllm.config import ( + CompilationMode, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, +) from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform -from vllm.utils import is_pin_memory_available +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.tree_attn import ( TreeAttentionMetadata, @@ -33,10 +37,10 @@ ) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -75,19 +79,32 @@ def __init__( vllm_config.model_config ) - self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None - self.draft_indexer_metadata_builder: Optional[AttentionMetadataBuilder] = None + self.attn_metadata_builder: AttentionMetadataBuilder | None = None + self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None self.attn_layer_names: list[str] = [] self.indexer_layer_names: list[str] = [] - self.use_cuda_graph = ( - not current_platform.is_xpu() - and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE - and not self.vllm_config.model_config.enforce_eager - and not 
self.speculative_config.enforce_eager - ) + self.use_cuda_graph = False + + compilation_config = self.vllm_config.compilation_config + if compilation_config.mode == CompilationMode.VLLM_COMPILE: + cudagraph_mode = compilation_config.cudagraph_mode + if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode( + CUDAGraphMode.PIECEWISE + ): + logger.warning( + "Currently the eagle proposer only supports cudagraph_mode " + "PIECEWISE, if you want the drafter to use cuda graphs, " + "please set compilation_config.cudagraph_mode to PIECEWISE " + "or FULL_AND_PIECEWISE" + ) + self.use_cuda_graph = ( + cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE) + and not self.speculative_config.enforce_eager + ) + self.cudagraph_batch_sizes = ( - list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes)) + (sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes)) if self.use_cuda_graph else [] ) @@ -132,7 +149,7 @@ def __init__( ) # Determine allowed attention backends once during initialization. - self.allowed_attn_types: Optional[tuple] = None + self.allowed_attn_types: tuple | None = None if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend @@ -190,10 +207,10 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - last_token_indices: Optional[torch.Tensor], + last_token_indices: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -216,11 +233,11 @@ def propose( assert self.runner is not None - # FIXME: need to consider multiple kv_cache_groups - ubatch_id = dbo_current_ubatch_id() - attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ - ubatch_id - ] + if self.attn_metadata_builder is None: + attn_metadata_builder = self._get_attention_metadata_builder() + else: + attn_metadata_builder = self.attn_metadata_builder + attn_metadata = attn_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0 ) @@ -239,12 +256,15 @@ def propose( per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata + for layer_name in self.indexer_layer_names: assert draft_indexer_metadata is not None per_layer_attn_metadata[layer_name] = draft_indexer_metadata + cudagraph_runtime_mode = CUDAGraphMode.NONE if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: num_input_tokens = num_tokens # copy inputs to buffer for cudagraph @@ -267,7 +287,10 @@ def propose( inputs_embeds = None with set_forward_context( - per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, ): ret_hidden_states = self.model( input_ids=input_ids, @@ -326,8 +349,10 @@ def propose( if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: 
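# --- Illustrative sketch (editor's example, not part of the patch) ---
# The padding decision around cuda graphs in standalone form: if the token
# count fits under the largest captured size, round it up to the nearest
# capture size and run the piecewise graph, otherwise run eagerly. The
# capture sizes and the mode strings below are invented stand-ins.
import bisect

CAPTURE_SIZES = sorted([1, 2, 4, 8, 16, 32])

def pad_for_cudagraph(num_tokens: int) -> tuple[int, str]:
    if num_tokens <= CAPTURE_SIZES[-1]:
        padded = CAPTURE_SIZES[bisect.bisect_left(CAPTURE_SIZES, num_tokens)]
        return padded, "PIECEWISE"
    return num_tokens, "NONE"

assert pad_for_cudagraph(3) == (4, "PIECEWISE")
assert pad_for_cudagraph(48) == (48, "NONE")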
input_batch_size = batch_size + cudagraph_runtime_mode = CUDAGraphMode.NONE common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 @@ -424,7 +449,10 @@ def propose( # Run the model. with set_forward_context( - per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size + per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size, + cudagraph_runtime_mode=cudagraph_runtime_mode, ): ret_hidden_states = self.model( input_ids=input_ids, @@ -597,6 +625,7 @@ def prepare_inputs_padded( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, + dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) token_indices_to_sample = ( @@ -730,11 +759,16 @@ def propose_tree( if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: num_input_tokens = num_tokens + cudagraph_runtime_mode = CUDAGraphMode.NONE # Run the model. with set_forward_context( - per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], @@ -868,6 +902,7 @@ def prepare_inputs( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, + dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) return spec_common_attn_metadata, token_indices @@ -880,7 +915,7 @@ def get_model_name(self, model: nn.Module) -> str: def load_model(self, target_model: nn.Module) -> None: draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() ) # FIXME: support hybrid kv for draft model target_indexer_layer_names = set( @@ -897,7 +932,7 @@ def load_model(self, target_model: nn.Module) -> None: ) draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() - target_attn_layer_names ) indexer_layers = get_layers_from_vllm_config( @@ -913,7 +948,7 @@ def load_model(self, target_model: nn.Module) -> None: indexer_layers[first_layer] .get_attn_backend() .get_builder_cls()( - indexer_layers[first_layer].get_kv_cache_spec(), + indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config), self.indexer_layer_names, self.vllm_config, self.device, @@ -1013,8 +1048,19 @@ def load_model(self, target_model: nn.Module) -> None: def dummy_run( self, num_tokens: int, + use_cudagraphs=True, ) -> None: - with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): + if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + + with set_forward_context( + None, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE + if use_cudagraphs + else CUDAGraphMode.NONE, + ): if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -1029,7 +1075,7 @@ def dummy_run( inputs_embeds=inputs_embeds, ) - def _get_attention_metadata_builder(self) -> 
list[AttentionMetadataBuilder]: + def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: """Find and return the attention metadata builders for EAGLE layers. Returns: @@ -1095,8 +1141,15 @@ def compute_probs_and_sample_next_token( next_token_ids = logits.argmax(dim=-1) return next_token_ids, probs - is_greedy = sampling_metadata.temperature == -1 - temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + assert sampling_metadata.temperature is not None + + # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0) + # consistent with sampler.py's _SAMPLING_EPS threshold + temperature = sampling_metadata.temperature + # Avoid division by zero if there are greedy requests. + if not sampling_metadata.all_random: + is_greedy = temperature < _SAMPLING_EPS + temperature = torch.where(is_greedy, 1.0, temperature) logits.div_(temperature.view(-1, 1)) probs = logits.softmax(dim=-1, dtype=torch.float32) diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py index d0695244cb16..6955ae79d01d 100644 --- a/vllm/v1/spec_decode/metadata.py +++ b/vllm/v1/spec_decode/metadata.py @@ -14,6 +14,8 @@ class SpecDecodeMetadata: num_draft_tokens: list[int] # [batch_size] cu_num_draft_tokens: torch.Tensor + # [batch_size] + cu_num_sampled_tokens: torch.Tensor # [num_tokens] target_logits_indices: torch.Tensor # [batch_size] @@ -32,6 +34,7 @@ def make_dummy( ) -> "SpecDecodeMetadata": batch_size = len(draft_token_ids) num_draft_tokens = [len(ids) for ids in draft_token_ids] + num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids] flattened_draft_token_ids = sum(draft_token_ids, []) num_tokens = len(flattened_draft_token_ids) @@ -40,6 +43,10 @@ def make_dummy( ) cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device) + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) + cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to( + device + ) target_logits_indices = torch.zeros( num_tokens, dtype=torch.int32, device=device @@ -52,6 +59,7 @@ def make_dummy( draft_token_ids=draft_token_ids_tensor, num_draft_tokens=num_draft_tokens, cu_num_draft_tokens=cu_num_draft_tokens_tensor, + cu_num_sampled_tokens=cu_num_sampled_tokens_tensor, target_logits_indices=target_logits_indices, bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 89a8a11a3d56..79d856a143ba 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -3,7 +3,6 @@ import time from dataclasses import dataclass, field -from typing import Optional import numpy as np import prometheus_client @@ -143,7 +142,7 @@ class SpecDecodingProm: def __init__( self, - speculative_config: Optional[SpeculativeConfig], + speculative_config: SpeculativeConfig | None, labelnames: list[str], per_engine_labelvalues: dict[int, list[str]], ): diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 1b5e75313d89..6f9dbeabd8ca 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import multiprocessing from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING @@ -10,7 +8,7 @@ from vllm.logger 
import init_logger from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs -from vllm.utils import LazyLoader +from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, @@ -28,6 +26,9 @@ else: torch = LazyLoader("torch", globals(), "torch") + ReasoningParser = object + Request = object + logger = init_logger(__name__) @@ -72,6 +73,10 @@ def __init__(self, vllm_config: VllmConfig): ) self.reasoner = reasoner_cls(tokenizer=self.tokenizer) + self.enable_in_reasoning = ( + self.vllm_config.structured_outputs_config.enable_in_reasoning + ) + def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: return @@ -166,9 +171,9 @@ def _async_submit_fill_bitmask( def grammar_bitmask( self, requests: dict[str, Request], - structured_output_request_ids: dict[str, int], + structured_output_request_ids: list[str], scheduled_spec_decode_tokens: dict[str, list[int]], - ) -> npt.NDArray[np.int32] | None: + ) -> "npt.NDArray[np.int32] | None": # Prepare the structured output bitmask for this batch. if not structured_output_request_ids: return None @@ -195,17 +200,16 @@ def grammar_bitmask( # masks for each request, one for each possible bonus token position. # These are stored inline in the tensor and unpacked by the gpu runner. cumulative_index = 0 - ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1]) # Optimized parallel filling of bitmasks for # non-spec, large-batch-size cases if ( - len(ordered_seq) > self.fill_bitmask_parallel_threshold + len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold and max_num_spec_tokens == 0 ): promises = [] batch = [] - for req_id, _ in ordered_seq: + for req_id in structured_output_request_ids: request = requests[req_id] structured_output_request = request.structured_output_request if TYPE_CHECKING: @@ -229,7 +233,7 @@ def grammar_bitmask( promise.result() else: # Fallback to serial filling of bitmasks for small-batch-size cases - for req_id, _ in ordered_seq: + for req_id in structured_output_request_ids: request = requests[req_id] structured_output_request = request.structured_output_request @@ -274,7 +278,13 @@ def grammar_bitmask( return bitmask_tensor.numpy() def should_fill_bitmask(self, request: Request) -> bool: + # NOTE (Hanchen) if enable_in_reasoning is True, it means that + # the model needs to be constrained in reasoning. So we should always + # enable the bitmask filling. + if self.reasoner is not None: + if self.enable_in_reasoning: + return True assert request.structured_output_request is not None if request.structured_output_request.reasoning_ended is None: request.structured_output_request.reasoning_ended = ( @@ -294,22 +304,25 @@ def should_advance(self, request: Request) -> bool: assert request.structured_output_request.grammar is not None # by default, we should always advance # for cases that don't use thinking mode. 
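# --- Illustrative sketch (editor's example, not part of the patch) ---
# The batch/serial split used when filling grammar bitmasks: above a
# threshold (and with no speculative tokens) the per-request fills are pushed
# to a thread pool, otherwise they run inline. The threshold, the worker
# function, and the request ids are all invented.
from concurrent.futures import ThreadPoolExecutor

PARALLEL_THRESHOLD = 4

def fill_bitmask(req_id: str, index: int) -> None:
    pass  # stand-in for grammar.fill_bitmask(bitmask, index)

def fill_all(request_ids: list[str], pool: ThreadPoolExecutor) -> None:
    if len(request_ids) > PARALLEL_THRESHOLD:
        futures = [pool.submit(fill_bitmask, rid, i)
                   for i, rid in enumerate(request_ids)]
        for f in futures:
            f.result()               # surface any exception from the workers
    else:
        for i, rid in enumerate(request_ids):
            fill_bitmask(rid, i)

with ThreadPoolExecutor(max_workers=2) as pool:
    fill_all([f"req-{i}" for i in range(8)], pool)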
- if self.reasoner is not None: - structured_req = request.structured_output_request - - if structured_req.reasoning_ended: - return True + if self.reasoner is None: + return True - # Check if reasoning ends in *this* step - if self.reasoner.is_reasoning_end(request.all_token_ids): - # Reasoning just ended, so we shouldn't advance til - # next pass - structured_req.reasoning_ended = True + # if the model needs structured in reasoning, we should advance + if self.enable_in_reasoning: + return True - return False - else: + structured_req = request.structured_output_request + if structured_req.reasoning_ended: return True + # Check if reasoning ends in *this* step + if self.reasoner.is_reasoning_end(request.all_token_ids): + # Reasoning just ended, so we shouldn't advance til + # next pass + structured_req.reasoning_ended = True + + return False + def clear_backend(self) -> None: if self.backend is not None: self.backend.destroy() diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 081cdfdc9932..00a625e103bd 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -1,19 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import copy import json import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any import torch from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.utils import LazyLoader +from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, StructuredOutputGrammar, @@ -47,7 +45,7 @@ def _walk_json_for_additional_properties(data: object): def process_for_additional_properties( - guide_json: Union[str, dict[str, Any]], + guide_json: str | dict[str, Any], ) -> dict[str, Any]: if isinstance(guide_json, str): guide_json_obj = json.loads(guide_json) @@ -184,12 +182,12 @@ def reset(self): def serialize_guidance_grammar( request_type: StructuredOutputOptions, - grammar_spec: Union[str, dict[str, Any]], + grammar_spec: str | dict[str, Any], disable_any_whitespace: bool = False, disable_additional_properties: bool = False, ) -> str: def _process_schema( - grammar_spec: Union[str, dict[str, Any]], + grammar_spec: str | dict[str, Any], ) -> str: if disable_additional_properties: grammar_spec = process_for_additional_properties(grammar_spec) @@ -254,7 +252,7 @@ def _process_schema( def validate_guidance_grammar( sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None ) -> None: - tp, grm = get_structured_output_key(sampling_params) + tp, grm = get_structured_output_key(sampling_params.structured_outputs) guidance_grm = serialize_guidance_grammar(tp, grm) err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer) if err: diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py index d9e484092d6a..150c57feda0f 100644 --- a/vllm/v1/structured_output/backend_lm_format_enforcer.py +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import ast import json from dataclasses import dataclass, field @@ -12,7 +10,7 @@ from transformers import 
PreTrainedTokenizerBase from vllm.sampling_params import SamplingParams -from vllm.utils import LazyLoader +from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, StructuredOutputGrammar, @@ -34,7 +32,7 @@ @lru_cache def _cached_build_vllm_token_enforcer_tokenizer_data( tokenizer: PreTrainedTokenizerBase, vocab_size: int -) -> lmfe_vllm.TokenEnforcerTokenizerData: +) -> "lmfe_vllm.TokenEnforcerTokenizerData": return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data( tokenizer, use_bitmask=True, vocab_size=vocab_size ) diff --git a/vllm/v1/structured_output/backend_outlines.py b/vllm/v1/structured_output/backend_outlines.py index c9875337179e..34916079f821 100644 --- a/vllm/v1/structured_output/backend_outlines.py +++ b/vllm/v1/structured_output/backend_outlines.py @@ -14,7 +14,7 @@ from regex import escape as regex_escape from vllm.sampling_params import SamplingParams -from vllm.utils import LazyLoader +from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, StructuredOutputGrammar, diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 2051b336e5bf..7dc9589b63b8 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import enum from abc import ABC, abstractmethod from dataclasses import dataclass @@ -13,6 +11,9 @@ from vllm.config import VllmConfig from vllm.transformers_utils.tokenizer import AnyTokenizer +else: + VllmConfig = object + AnyTokenizer = object class StructuredOutputOptions(enum.Enum): @@ -69,7 +70,7 @@ def rollback(self, num_tokens: int) -> None: """ @abstractmethod - def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + def fill_bitmask(self, bitmask: "torch.Tensor", batch_index: int) -> None: """ Fills the bitmask for a specific batch index. @@ -119,7 +120,7 @@ def compile_grammar( """ @abstractmethod - def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + def allocate_token_bitmask(self, max_num_seqs: int) -> "torch.Tensor": """ Allocates a token bitmask for the specified maximum number of sequences. diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 9f81d09633d7..c9f2dc07da78 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import json from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -13,7 +11,7 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer -from vllm.utils import LazyLoader +from vllm.utils.import_utils import LazyLoader from vllm.v1.structured_output.backend_types import ( StructuredOutputBackend, StructuredOutputGrammar, @@ -43,34 +41,13 @@ def __post_init__(self): if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. 
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 - try: - if self.tokenizer.is_tekken: - encoded_vocab = self.tokenizer._vocab - else: - encoded_vocab = [ - token - for token, _ in sorted( - self.tokenizer.get_vocab().items(), - key=lambda x: x[1], - ) - ] - stop_token_ids = None - if ( - hasattr( - self.tokenizer, - "eos_token_id", - ) - and self.tokenizer.eos_token_id is not None - ): - stop_token_ids = [self.tokenizer.eos_token_id] - except AttributeError as e: - raise ValueError( - f"Cannot get the vocabulary of the tokenizer " - f"{type(self.tokenizer)}. The tokenizer should have a " - "get_vocab method." - ) from e + stop_token_ids = [self.tokenizer.eos_token_id] + + # not self.tokenizer.vocab_size as self.tokenizer.vocab + # collapses all decoded errors into a single token. + self.vocab_size = len(self.tokenizer.vocab) tokenizer_info = xgr.TokenizerInfo( # type: ignore - encoded_vocab=encoded_vocab, + encoded_vocab=self.tokenizer.vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW if self.tokenizer.is_tekken @@ -114,18 +91,19 @@ def compile_grammar( ctx = self.compiler.compile_regex(grammar_spec) elif request_type == StructuredOutputOptions.STRUCTURAL_TAG: s_tag = json.loads(grammar_spec) - tags = [ - xgr.StructuralTagItem( - begin=s["begin"], - schema=json.dumps(s["schema"]), - end=s["end"], - ) - for s in s_tag["structures"] - ] - structural_tag = xgr.StructuralTag.from_legacy_structural_tag( - tags, s_tag["triggers"] - ) - ctx = self.compiler.compile_structural_tag(structural_tag) + if "structures" in s_tag: + # Falling back to deprecated method of compiling structural tag + tags = [ + xgr.StructuralTagItem( + begin=s["begin"], + schema=json.dumps(s["schema"]), + end=s["end"], + ) + for s in s_tag["structures"] + ] + ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"]) + else: + ctx = self.compiler.compile_structural_tag(grammar_spec) else: logger.error( "Validation should have already occurred. Please file an issue." 
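# --- Illustrative sketch (editor's example, not part of the patch) ---
# The two structural-tag layouts handled above: legacy specs carry a
# "structures"/"triggers" pair and are converted tag by tag, newer specs are
# passed through as raw JSON. The return values are placeholders, not
# xgrammar objects, and the function name is not xgrammar's API.
import json

def compile_structural_tag(spec: str):
    parsed = json.loads(spec)
    if "structures" in parsed:                       # deprecated layout
        tags = [(s["begin"], json.dumps(s["schema"]), s["end"])
                for s in parsed["structures"]]
        return ("legacy", tags, parsed["triggers"])
    return ("modern", spec)

legacy_spec = json.dumps({
    "structures": [{"begin": "<tool>", "schema": {}, "end": "</tool>"}],
    "triggers": ["<tool>"],
})
assert compile_structural_tag(legacy_spec)[0] == "legacy"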
@@ -221,6 +199,25 @@ def reset(self): self.matcher.reset() +# cf https://github.com/mlc-ai/xgrammar/blob/a32ac892676d2eedc0327416105b9b06edfb94b2/cpp/json_schema_converter.cc +STRING_SUPPORTED_FORMATS = { + "email", + "date", + "time", + "date-time", + "duration", + "ipv4", + "ipv6", + "hostname", + "uuid", + "uri", + "uri-reference", + "uri-template", + "json-pointer", + "relative-json-pointer", +} + + def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool: """Check if JSON schema contains features unsupported by xgrammar.""" @@ -240,7 +237,11 @@ def check_object(obj: dict[str, Any]) -> bool: return True # Unsupported keywords for strings - if obj.get("type") == "string" and "format" in obj: + if ( + obj.get("type") == "string" + and "format" in obj + and obj["format"] not in STRING_SUPPORTED_FORMATS + ): return True # Unsupported keywords for objects @@ -343,17 +344,19 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: if so_params.structural_tag: try: s_tag = json.loads(so_params.structural_tag) - tags = [ - xgr.StructuralTagItem( - begin=s["begin"], - schema=json.dumps(s["schema"]), - end=s["end"], - ) - for s in s_tag["structures"] - ] - structural_tag = xgr.StructuralTag.from_legacy_structural_tag( - tags, s_tag["triggers"] - ) - xgr.Grammar.from_structural_tag(structural_tag) + + # Using the deprecated method of compiling structural tag + if "structures" in s_tag: + tags = [ + xgr.StructuralTagItem( + begin=s["begin"], + schema=json.dumps(s["schema"]), + end=s["end"], + ) + for s in s_tag["structures"] + ] + xgr.Grammar.from_structural_tag(tags, s_tag["triggers"]) + else: + xgr.Grammar.from_structural_tag(so_params.structural_tag) except Exception as e: raise ValueError("Invalid structural tag specification.") from e diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 233c7c1e7805..94ae36a1abb4 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from __future__ import annotations - import dataclasses import functools import json from concurrent.futures import Future from concurrent.futures._base import TimeoutError -from typing import Optional, Union, cast +from typing import cast -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.v1.structured_output.backend_types import ( StructuredOutputGrammar, StructuredOutputKey, @@ -19,12 +17,24 @@ @dataclasses.dataclass class StructuredOutputRequest: - sampling_params: SamplingParams - _grammar: Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] | None = ( - None - ) + params: StructuredOutputsParams + _grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None reasoning_ended: bool | None = None + @staticmethod + def from_sampling_params( + sampling_params: SamplingParams | None, + ) -> "StructuredOutputRequest | None": + if sampling_params is None: + return None + params = sampling_params.structured_outputs + if params: + if params.all_constraints_none(): + return None + else: + return StructuredOutputRequest(params=params) + return None + def _check_grammar_completion(self) -> bool: # NOTE: We have to lazy import to gate circular imports from vllm.v1.request import RequestStatus @@ -46,44 +56,39 @@ def is_grammar_ready(self) -> bool: def grammar(self) -> 
StructuredOutputGrammar | None: completed = self._check_grammar_completion() return ( - cast(Optional[StructuredOutputGrammar], self._grammar) - if completed - else None + cast(StructuredOutputGrammar | None, self._grammar) if completed else None ) @grammar.setter def grammar( - self, grammar: Union[StructuredOutputGrammar, Future[StructuredOutputGrammar]] + self, grammar: StructuredOutputGrammar | Future[StructuredOutputGrammar] ) -> None: self._grammar = grammar @functools.cached_property def structured_output_key(self) -> StructuredOutputKey: - return get_structured_output_key(self.sampling_params) + return get_structured_output_key(self.params) -def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey: - params = sampling_params.structured_outputs - assert params is not None, "params can't be None." +def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey: if params.json is not None: if not isinstance(params.json, str): json_str = json.dumps(params.json) else: json_str = params.json - return (StructuredOutputOptions.JSON, json_str) - elif params.json_object: - return (StructuredOutputOptions.JSON_OBJECT, "") - elif params.regex is not None: - return (StructuredOutputOptions.REGEX, params.regex) - elif params.choice is not None: + return StructuredOutputOptions.JSON, json_str + if params.json_object: + return StructuredOutputOptions.JSON_OBJECT, "" + if params.regex is not None: + return StructuredOutputOptions.REGEX, params.regex + if params.choice is not None: if not isinstance(params.choice, str): json_str = json.dumps(params.choice) else: json_str = params.choice - return (StructuredOutputOptions.CHOICE, json_str) - elif params.grammar is not None: - return (StructuredOutputOptions.GRAMMAR, params.grammar) - elif params.structural_tag is not None: - return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag) - else: - raise ValueError("No valid structured output parameter found") + return StructuredOutputOptions.CHOICE, json_str + if params.grammar is not None: + return StructuredOutputOptions.GRAMMAR, params.grammar + if params.structural_tag is not None: + return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag + raise ValueError("No valid structured output parameter found") diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index b7326847d016..ef9bae2367be 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - from __future__ import annotations import hashlib @@ -16,7 +15,7 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import LazyLoader +from vllm.utils.import_utils import LazyLoader if TYPE_CHECKING: import outlines_core as oc @@ -37,6 +36,10 @@ "transformers.models.gpt2.tokenization_gpt2", ) + AnyTokenizer = object + SchedulerOutput = object + InputBatch = object + logger = init_logger(__name__) CACHE = None @@ -46,7 +49,6 @@ def apply_grammar_bitmask( scheduler_output: SchedulerOutput, input_batch: InputBatch, logits: torch.Tensor, - device: torch.device, ) -> None: """ Apply grammar bitmask to output logits of the model with xgrammar function. @@ -55,7 +57,6 @@ def apply_grammar_bitmask( scheduler_output (SchedulerOutput): The result of engine scheduling. input_batch (InputBatch): The input of model runner. logits (torch.Tensor): The output logits of model forward. 
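# --- Illustrative sketch (editor's example, not part of the patch) ---
# The key-dispatch convention behind get_structured_output_key: each
# constraint kind maps to an (option, canonical string) pair, with non-string
# JSON payloads normalized through json.dumps so equal specs share a cache
# key. The enum below is a stand-in for StructuredOutputOptions.
import enum
import json

class FakeOption(enum.Enum):
    JSON = enum.auto()
    REGEX = enum.auto()
    GRAMMAR = enum.auto()

def output_key(json_spec=None, regex=None, grammar=None) -> tuple[FakeOption, str]:
    if json_spec is not None:
        spec = json_spec if isinstance(json_spec, str) else json.dumps(json_spec)
        return FakeOption.JSON, spec
    if regex is not None:
        return FakeOption.REGEX, regex
    if grammar is not None:
        return FakeOption.GRAMMAR, grammar
    raise ValueError("No valid structured output parameter found")

assert output_key(json_spec={"type": "object"}) == (FakeOption.JSON, '{"type": "object"}')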
- device (torch.device): The device that model runner running on. """ grammar_bitmask = scheduler_output.grammar_bitmask if grammar_bitmask is None: @@ -90,10 +91,7 @@ def apply_grammar_bitmask( dtype=grammar_bitmask.dtype, ) cumulative_index = 0 - seq = sorted( - scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1] - ) - for req_id, _ in seq: + for req_id in scheduler_output.structured_output_request_ids: num_spec_tokens = len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) ) @@ -116,7 +114,7 @@ def apply_grammar_bitmask( xgr.apply_token_bitmask_inplace( logits, - grammar_bitmask.to(device, non_blocking=True), + grammar_bitmask.to(logits.device, non_blocking=True), indices=out_indices if not skip_out_indices else None, ) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 925943262894..789a74cc6c4a 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -5,14 +5,13 @@ import multiprocessing import time import weakref -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import AbstractContextManager from multiprocessing import connection from multiprocessing.process import BaseProcess from typing import ( TYPE_CHECKING, Any, - Callable, Generic, Optional, TypeVar, @@ -26,12 +25,8 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message -from vllm.utils import ( - get_open_port, - get_open_zmq_ipc_path, - get_tcp_uri, - kill_process_tree, -) +from vllm.utils import kill_process_tree +from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri if TYPE_CHECKING: import numpy as np @@ -66,7 +61,7 @@ def remove(self, item): def clear(self): raise TypeError("Cannot clear a constant list") - def index(self, item: T, start: int = 0, stop: Optional[int] = None) -> int: + def index(self, item: T, start: int = 0, stop: int | None = None) -> int: return self._x.index(item, start, stop if stop is not None else len(self._x)) @overload @@ -75,7 +70,7 @@ def __getitem__(self, item: int) -> T: ... @overload def __getitem__(self, s: slice, /) -> list[T]: ... - def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]: + def __getitem__(self, item: int | slice) -> T | list[T]: return self._x[item] @overload @@ -84,7 +79,7 @@ def __setitem__(self, item: int, value: T): ... @overload def __setitem__(self, s: slice, value: T, /): ... 
- def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]): + def __setitem__(self, item: int | slice, value: T | list[T]): raise TypeError("Cannot set item in a constant list") def __delitem__(self, item): @@ -108,7 +103,7 @@ class CpuGpuBuffer: def __init__( self, - *size: Union[int, torch.SymInt], + *size: int | torch.SymInt, dtype: torch.dtype, device: torch.device, pin_memory: bool, @@ -128,12 +123,12 @@ def __init__( ) self.np = self.cpu.numpy() - def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: + def copy_to_gpu(self, n: int | None = None) -> torch.Tensor: if n is None: return self.gpu.copy_(self.cpu, non_blocking=True) return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True) - def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor: + def copy_to_cpu(self, n: int | None = None) -> torch.Tensor: """NOTE: Because this method is non-blocking, explicit synchronization is needed to ensure the data is copied to CPU.""" if n is None: @@ -173,7 +168,7 @@ def __init__( num_servers: int, input_addresses: list[str], output_addresses: list[str], - stats_update_address: Optional[str] = None, + stats_update_address: str | None = None, ): """Initialize and start API server worker processes. @@ -227,9 +222,8 @@ def close(self) -> None: def wait_for_completion_or_failure( api_server_manager: APIServerProcessManager, - engine_manager: Optional[ - Union["CoreEngineProcManager", "CoreEngineActorManager"] - ] = None, + engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"] + | None = None, coordinator: Optional["DPCoordinator"] = None, ) -> None: """Wait for all processes to complete or detect if any fail. @@ -347,13 +341,17 @@ def report_usage_stats( parallel_config = vllm_config.parallel_config + # Prepare KV connector string if applicable + kv_connector = None + if vllm_config.kv_transfer_config is not None: + kv_connector = vllm_config.kv_transfer_config.kv_connector + usage_message.report_usage( get_architecture_class_name(vllm_config.model_config), usage_context, extra_kvs={ # Common configuration "dtype": str(vllm_config.model_config.dtype), - "tensor_parallel_size": parallel_config.tensor_parallel_size, "block_size": vllm_config.cache_config.block_size, "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization, "kv_cache_memory_bytes": vllm_config.cache_config.kv_cache_memory_bytes, @@ -365,6 +363,15 @@ def report_usage_stats( "enable_prefix_caching": vllm_config.cache_config.enable_prefix_caching, "enforce_eager": vllm_config.model_config.enforce_eager, "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, + # Distributed parallelism settings + "tensor_parallel_size": parallel_config.tensor_parallel_size, + "data_parallel_size": parallel_config.data_parallel_size, + "pipeline_parallel_size": parallel_config.pipeline_parallel_size, + "enable_expert_parallel": parallel_config.enable_expert_parallel, + # All2All backend for MoE expert parallel + "all2all_backend": parallel_config.all2all_backend, + # KV connector used + "kv_connector": kv_connector, }, ) @@ -389,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: _PROFILER_FUNC = func return func(name) + + +def tensor_data(tensor: torch.Tensor) -> memoryview: + """Get the raw data of a tensor as a uint8 memoryview, useful for + serializing and hashing. + + Args: + tensor: The input tensor. + + Returns: + A memoryview of the tensor data as uint8. 
+ """ + return tensor.flatten().contiguous().view(torch.uint8).numpy().data diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 4d3688453cb9..9bf06d51609f 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Union import numpy as np import torch @@ -22,22 +21,64 @@ def __init__( max_num_batched_tokens: int, pin_memory: bool, device: torch.device, + kernel_block_size: int, ): - self.block_size = block_size + """ + Args: + block_size: Block size used for KV cache memory allocation + max_num_reqs: Maximum number of concurrent requests supported. + max_num_blocks_per_req: Maximum number of blocks per request. + max_num_batched_tokens: Maximum number of tokens in a batch. + pin_memory: Whether to pin memory for faster GPU transfers. + device: Target device for the block table. + kernel_block_size: The block_size of underlying attention kernel. + Will be the same as `block_size` if `block_size` is supported + by the attention kernel. + """ self.max_num_reqs = max_num_reqs - self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device + if kernel_block_size == block_size: + # Standard case: allocation and computation use same block size + # No block splitting needed, direct mapping + self.block_size = block_size + self.blocks_per_kv_block = 1 + self.use_hybrid_blocks = False + else: + # Hybrid case: allocation block size differs from kernel block size + # Memory blocks are subdivided to match kernel requirements + # Example: 32-token memory blocks with 16-token kernel blocks + # → Each memory block corresponds to 2 kernel blocks + if block_size % kernel_block_size != 0: + raise ValueError( + f"kernel_block_size {kernel_block_size} must divide " + f"kv_manager_block_size size {block_size} evenly" + ) + + self.block_size = kernel_block_size + self.blocks_per_kv_block = block_size // kernel_block_size + self.use_hybrid_blocks = True + + self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block + self.block_table = self._make_buffer( - max_num_reqs, max_num_blocks_per_req, dtype=torch.int32 + self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32 ) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) self.slot_mapping = self._make_buffer( self.max_num_batched_tokens, dtype=torch.int64 ) + + if self.use_hybrid_blocks: + self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape( + 1, -1 + ) + else: + self._kernel_block_arange = None + try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group @@ -53,6 +94,10 @@ def append_row( ) -> None: if not block_ids: return + + if self.use_hybrid_blocks: + block_ids = self._map_to_kernel_blocks(np.array(block_ids)) + num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks @@ -94,6 +139,7 @@ def compute_slot_mapping( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size ) + block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. 
@@ -111,6 +157,7 @@ def compute_slot_mapping( block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // self.block_size ) + block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size np.add( @@ -129,6 +176,31 @@ def clear(self) -> None: self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) + def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray: + """Convert kv_manager_block_id IDs to kernel block IDs. + + Example: + # kv_manager_block_ids: 32 tokens, + # Kernel block size: 16 tokens + # blocks_per_kv_block = 2 + >>> kv_manager_block_ids = np.array([0, 1, 2]) + >>> Result: [0, 1, 2, 3, 4, 5] + + # Each kv_manager_block_id maps to 2 kernel block id: + # kv_manager_block_id 0 → kernel block id [0, 1] + # kv_manager_block_id 1 → kernel block id [2, 3] + # kv_manager_block_id 2 → kernel block id [4, 5] + """ + if not self.use_hybrid_blocks: + return kv_manager_block_ids + + kernel_block_ids = ( + kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block + + self._kernel_block_arange + ) + + return kernel_block_ids.reshape(-1) + def get_device_tensor(self, num_reqs: int) -> torch.Tensor: """Returns the device tensor of the block table.""" return self.block_table.gpu[:num_reqs] @@ -142,7 +214,7 @@ def get_numpy_array(self) -> np.ndarray: return self.block_table.np def _make_buffer( - self, *size: Union[int, torch.SymInt], dtype: torch.dtype + self, *size: int | torch.SymInt, dtype: torch.dtype ) -> CpuGpuBuffer: return CpuGpuBuffer( *size, dtype=dtype, device=self.device, pin_memory=self.pin_memory @@ -160,6 +232,7 @@ def __init__( pin_memory: bool, device: torch.device, block_sizes: list[int], + kernel_block_sizes: list[int], num_speculative_tokens: int = 0, ) -> None: # Note(hc): each dcp rank only store @@ -172,6 +245,12 @@ def __init__( # DCP might not be initialized in testing dcp_world_size = 1 + if len(kernel_block_sizes) != len(block_sizes): + raise ValueError( + f"kernel_block_sizes length ({len(kernel_block_sizes)}) " + f"must match block_sizes length ({len(block_sizes)})" + ) + self.block_tables = [ BlockTable( block_size, @@ -183,8 +262,9 @@ def __init__( max_num_batched_tokens, pin_memory, device, + kernel_block_size, ) - for block_size in block_sizes + for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) ] def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 299567427027..5aebfec06dfd 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -95,7 +95,7 @@ def _sync_device(self) -> None: def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: return sampled_token_ids.tolist() - def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + def get_dp_padding(self, num_tokens: int) -> tuple[int, torch.Tensor | None]: # Note: For CPU backend, dp padding is not required for now. 
return 0, None diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index ee865ec8e649..5b57df2d472c 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os import platform -from typing import Callable, Optional +from collections.abc import Callable import torch @@ -91,7 +91,7 @@ def sleep(self, level: int = 1) -> None: logger.warning("sleep mode is not supported on CPU, ignore it.") pass - def wake_up(self, tags: Optional[list[str]] = None) -> None: + def wake_up(self, tags: list[str] | None = None) -> None: logger.warning("sleep mode is not supported on CPU, ignore it.") pass @@ -128,7 +128,7 @@ def _get_autobind_cpu_ids( "Please try to bind threads manually." ) - # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`` + # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]` selected_numa_node = allowed_numa_nodes[self.local_rank] # type: ignore logical_cpu_list = [ x for x in logical_cpu_list if x.numa_node == selected_numa_node diff --git a/vllm/v1/worker/dp_utils.py b/vllm/v1/worker/dp_utils.py index 7a943909a8ba..2b2a69f4af3a 100644 --- a/vllm/v1/worker/dp_utils.py +++ b/vllm/v1/worker/dp_utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import numpy as np import torch @@ -37,6 +36,7 @@ def _get_device_and_group(parallel_config: ParallelConfig): def _run_ar( should_ubatch: bool, + should_dp_pad: bool, orig_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int, parallel_config: ParallelConfig, @@ -44,10 +44,11 @@ def _run_ar( dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank device, group = _get_device_and_group(parallel_config) - tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32) + tensor = torch.zeros(4, dp_size, device=device, dtype=torch.int32) tensor[0][dp_rank] = orig_num_tokens_per_ubatch tensor[1][dp_rank] = padded_num_tokens_per_ubatch tensor[2][dp_rank] = 1 if should_ubatch else 0 + tensor[3][dp_rank] = 1 if should_dp_pad else 0 dist.all_reduce(tensor, group=group) return tensor @@ -72,68 +73,115 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool: return should_ubatch +def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch.Tensor: + num_tokens_across_dp = tensor[1, :] + if should_dp_pad: + # If DP padding is enabled, ensure that each rank is processing the same number + # of tokens + max_num_tokens = int(num_tokens_across_dp.max().item()) + return torch.tensor( + [max_num_tokens] * len(num_tokens_across_dp), + device="cpu", + dtype=torch.int32, + ) + else: + return num_tokens_across_dp.cpu() + + def _synchronize_dp_ranks( num_tokens_unpadded: int, num_tokens_padded: int, should_attempt_ubatching: bool, + should_attempt_dp_padding: bool, parallel_config: ParallelConfig, -) -> tuple[bool, Optional[torch.Tensor]]: +) -> tuple[bool, torch.Tensor | None]: """ 1. Decides if each DP rank is going to microbatch. Either all ranks run with microbatching or none of them do. 2. Determines the total number of tokens that each rank will run. 
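# Editorial aside: a CPU-only sketch of the 4 x dp_size coordination tensor
# built by _run_ar above (rows: unpadded tokens, padded tokens, should_ubatch
# flag, should_dp_pad flag; one column per DP rank) and of the padding logic
# in _post_process_dp_padding. The values here are made up; in vLLM each rank
# fills its own column and the columns are combined with an all-reduce.
import torch

reduced = torch.tensor(
    [
        [96, 130],   # row 0: unpadded token count on each DP rank
        [128, 160],  # row 1: padded (e.g. CUDA-graph) token count on each rank
        [0, 0],      # row 2: per-rank should_ubatch votes (reduced elsewhere)
        [1, 1],      # row 3: per-rank should_dp_pad votes
    ],
    dtype=torch.int32,
)

should_dp_pad = bool(torch.all(reduced[3] == 1).item())

num_tokens_across_dp = reduced[1, :]
if should_dp_pad:
    # Pad every DP rank up to the max so all ranks run the same token count.
    max_tokens = int(num_tokens_across_dp.max().item())
    num_tokens_across_dp = torch.full_like(num_tokens_across_dp, max_tokens)

assert num_tokens_across_dp.tolist() == [160, 160]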
- All ranks will be padded out so that the run with the same number - of tokens + When running microbatched or if should_attempt_dp_padding is True, all + ranks will be padded out so that the run with the same number of tokens Returns: tuple[ should_ubatch: Are all DP ranks going to microbatch num_tokens_after_padding: A tensor containing the total number of - tokens per-microbatch for each DP rank including padding. + tokens per-microbatch for each DP rank including any DP padding. ] """ assert num_tokens_padded >= num_tokens_unpadded - # First we coordinate between the DP ranks via an All Reduce + # Coordinate between the DP ranks via an All Reduce # to determine the total number of tokens that each rank # will run and if we are using ubatching or not. tensor = _run_ar( should_ubatch=should_attempt_ubatching, + should_dp_pad=should_attempt_dp_padding, orig_num_tokens_per_ubatch=num_tokens_unpadded, padded_num_tokens_per_ubatch=num_tokens_padded, parallel_config=parallel_config, ) - # Ensure that each rank is processing the same nuber of tokens - num_tokens_across_dp = tensor[1, :] - max_num_tokens = int(num_tokens_across_dp.max().item()) - num_tokens_after_padding = torch.tensor( - [max_num_tokens] * len(num_tokens_across_dp), device="cpu", dtype=torch.int32 - ) + should_dp_pad = bool(torch.all(tensor[3] == 1).item()) + + # DP ranks should all have the same value for should_attempt_dp_padding. + assert should_attempt_dp_padding == should_dp_pad + # Check conditions for microbatching should_ubatch = _post_process_ubatch(tensor) + if should_ubatch and not should_dp_pad: + logger.debug_once( + "Microbatching has been triggered and requires DP padding. " + "Enabling DP padding even though it has been explicitly " + "disabled.", + scope="global", + ) + should_dp_pad = True + + # Pad all DP ranks up to the maximum token count across ranks if + # should_dp_pad is True + num_tokens_after_padding = _post_process_dp_padding( + tensor, + should_dp_pad, + ) + return should_ubatch, num_tokens_after_padding def coordinate_batch_across_dp( - num_scheduled_tokens_per_request: np.ndarray, num_tokens_unpadded: int, - num_tokens_padded: int, - parallel_config: ParallelConfig, allow_microbatching: bool, - uniform_decode: bool, -) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]: + allow_dp_padding: bool, + parallel_config: ParallelConfig, + num_tokens_padded: int | None = None, + uniform_decode: bool | None = None, + num_scheduled_tokens_per_request: np.ndarray | None = None, +) -> tuple[UBatchSlices | None, torch.Tensor | None]: """ Coordinates amongst all DP ranks to determine if and how the full batch should be split into microbatches. + Args: + num_tokens_unpadded: Number of tokens without accounting for padding + allow_microbatching: If microbatching should be attempted + allow_dp_padding: If all DP ranks should be padded up to the same value + parallel_config: The parallel config + num_tokens_padded: Number of tokens including any non-DP padding (CUDA graphs, + TP, etc) + uniform_decode: Only used if allow_microbatching is True. True if the batch + only contains single token decodes + num_scheduled_tokens_per_request: Only used if allow_microbatching is True. The + number of tokens per request. + Returns: tuple[ ubatch_slices: if this is set then all DP ranks have agreed to microbatch num_tokens_after_padding: A tensor containing the total number of - tokens per-microbatch for each DP rank including padding. + tokens per-microbatch for each DP rank including padding. 
Will be + padded up to the max value across all DP ranks when allow_dp_padding + is True. ] """ @@ -141,21 +189,25 @@ def coordinate_batch_across_dp( # Early exit. return None, None - # Check preconditions for microbatching - should_attempt_ubatching = check_ubatch_thresholds( - parallel_config, - num_tokens_unpadded, - uniform_decode=uniform_decode, - ) + # If the caller has explicitly enabled microbatching. + should_attempt_ubatching = False + if allow_microbatching: + # Check preconditions for microbatching + assert uniform_decode is not None + should_attempt_ubatching = check_ubatch_thresholds( + parallel_config, + num_tokens_unpadded, + uniform_decode=uniform_decode, + ) - # If the caller has explicitly disabled microbatching. - if not allow_microbatching: - should_attempt_ubatching = False + if num_tokens_padded is None: + num_tokens_padded = num_tokens_unpadded (should_ubatch, num_tokens_after_padding) = _synchronize_dp_ranks( num_tokens_unpadded, num_tokens_padded, should_attempt_ubatching, + allow_dp_padding, parallel_config, ) @@ -170,6 +222,7 @@ def coordinate_batch_across_dp( assert num_tokens_after_padding is not None token_split_point = int(num_tokens_after_padding[0].item()) // 2 + assert num_scheduled_tokens_per_request is not None ubatch_slices = create_ubatch_slices( num_scheduled_tokens_per_request, token_split_point ) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 22f5c6f7e683..476c3edefb84 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -3,7 +3,7 @@ # Datastructures defining a GPU input batch from dataclasses import dataclass -from typing import Optional, cast +from typing import cast import numpy as np import torch @@ -12,7 +12,8 @@ from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values +from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.utils.collection_utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import ( @@ -29,21 +30,21 @@ @dataclass class CachedRequestState: req_id: str - prompt_token_ids: Optional[list[int]] + prompt_token_ids: list[int] | None mm_features: list[MultiModalFeatureSpec] - sampling_params: Optional[SamplingParams] - pooling_params: Optional[PoolingParams] - generator: Optional[torch.Generator] + sampling_params: SamplingParams | None + pooling_params: PoolingParams | None + generator: torch.Generator | None block_ids: tuple[list[int], ...] num_computed_tokens: int output_token_ids: list[int] - mrope_positions: Optional[torch.Tensor] = None - mrope_position_delta: Optional[int] = None + mrope_positions: torch.Tensor | None = None + mrope_position_delta: int | None = None - lora_request: Optional[LoRARequest] = None - prompt_embeds: Optional[torch.Tensor] = None + lora_request: LoRARequest | None = None + prompt_embeds: torch.Tensor | None = None def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( @@ -62,10 +63,9 @@ def get_token_id(self, idx: int) -> int: "provided via prompt_embeds, and its ID is unknown." 
) return self.prompt_token_ids[idx] - elif idx - self.num_prompt_tokens < len(self.output_token_ids): + if idx - self.num_prompt_tokens < len(self.output_token_ids): return self.output_token_ids[idx - self.num_prompt_tokens] - else: - return -1 + return -1 class InputBatch: @@ -78,7 +78,9 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group - logitsprocs: Optional[LogitsProcessors] = None, + kernel_block_sizes: list[int], + logitsprocs: LogitsProcessors | None = None, + logitsprocs_need_output_token_ids: bool = False, is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, @@ -92,7 +94,7 @@ def __init__( self.pin_memory = pin_memory self.vocab_size = vocab_size - self._req_ids: list[Optional[str]] = [] + self._req_ids: list[str | None] = [] self.req_id_to_index: dict[str, int] = {} # TODO(woosuk): This buffer could be too large if max_model_len is big. @@ -132,6 +134,7 @@ def __init__( pin_memory=pin_memory, device=device, block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, num_speculative_tokens=num_speculative_tokens, ) @@ -226,22 +229,23 @@ def __init__( self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, # the value is False. Since we use masked_fill_ to set -inf. - self.allowed_token_ids_mask: Optional[torch.Tensor] = None - self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None + self.allowed_token_ids_mask: torch.Tensor | None = None + self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None # req_index -> bad_words_token_ids self.bad_words_token_ids: dict[int, list[list[int]]] = {} self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool) - self.req_output_token_ids: list[Optional[list[int]]] = [] + self.req_output_token_ids: list[list[int] | None] = [] # Store provided logitsprocs. If none are provided, initialize empty # data structure self.logitsprocs = logitsprocs or LogitsProcessors() + self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids # Store last speculative tokens for sampler. - self.spec_token_ids: list[Optional[list[int]]] = [] + self.spec_token_ids: list[list[int] | None] = [] # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() @@ -249,9 +253,13 @@ def __init__( self.pooling_params: dict[str, PoolingParams] = {} # Cached reference to the GPU tensor of previously sampled tokens - self.prev_sampled_token_ids: Optional[torch.Tensor] = None - self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None - self.prev_req_id_to_index: Optional[dict[str, int]] = None + self.prev_sampled_token_ids: torch.Tensor | None = None + self.prev_req_id_to_index: dict[str, int] | None = None + # These are used to update output_token_ids with real sampled + # ids from prior step, if required by current sampling params + # (e.g. penalties). + self.sampled_token_ids_cpu: torch.Tensor | None = None + self.async_copy_ready_event: torch.cuda.Event | None = None @property def req_ids(self) -> list[str]: @@ -431,7 +439,7 @@ def add_request( return req_index - def remove_request(self, req_id: str) -> Optional[int]: + def remove_request(self, req_id: str) -> int | None: """This method must always be followed by a call to condense(). 
Args: @@ -768,16 +776,28 @@ def _make_sampling_metadata(self) -> SamplingMetadata: not self.no_penalties or self.logits_processing_needs_token_ids[:num_reqs].any() ) - if needs_prompt_token_ids: - # The prompt tokens are used only for applying penalties or - # step pooling during the sampling/pooling process. - # Hence copy these tensors only when there are requests which - # need penalties/step_pooler to be applied. - prompt_token_ids = self._make_prompt_token_ids_tensor() - else: - prompt_token_ids = None + # The prompt tokens are used only for applying penalties or + # step pooling during the sampling/pooling process. + # Hence copy these tensors only when there are requests which + # need penalties/step_pooler to be applied. + prompt_token_ids = ( + self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None + ) + + # Only set output_token_ids if required by the current requests' + # sampling parameters. + needs_output_token_ids = ( + not self.no_penalties + or bool(self.bad_words_token_ids) + or self.logitsprocs_need_output_token_ids + ) + output_token_ids = ( + cast(list[list[int]], self.req_output_token_ids) + if needs_output_token_ids + else [] + ) - allowed_token_ids_mask: Optional[torch.Tensor] = None + allowed_token_ids_mask: torch.Tensor | None = None if not self.no_allowed_token_ids: assert self.allowed_token_ids_mask is not None copy_slice( @@ -799,7 +819,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata: frequency_penalties=self.frequency_penalties[:num_reqs], presence_penalties=self.presence_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs], - output_token_ids=cast(list[list[int]], self.req_output_token_ids), + output_token_ids=output_token_ids, spec_token_ids=cast(list[list[int]], self.spec_token_ids), no_penalties=self.no_penalties, allowed_token_ids_mask=allowed_token_ids_mask, @@ -860,6 +880,52 @@ def make_lora_inputs( return prompt_lora_mapping, token_lora_mapping, active_lora_requests + def set_async_sampled_token_ids( + self, + sampled_token_ids_cpu: torch.Tensor, + async_copy_ready_event: torch.cuda.Event, + ) -> None: + """ + In async scheduling case, store ref to sampled_token_ids_cpu + tensor and corresponding copy-ready event. Used to repair + output_token_ids prior to sampling, if needed by logits processors. + """ + if self.sampling_metadata.output_token_ids: + self.sampled_token_ids_cpu = sampled_token_ids_cpu + self.async_copy_ready_event = async_copy_ready_event + else: + self.sampled_token_ids_cpu = None + self.async_copy_ready_event = None + + def update_async_output_token_ids(self) -> None: + """ + In async scheduling case, update output_token_ids in sampling metadata + from prior steps sampled token ids once they've finished copying to CPU. + This is called right before they are needed by the logits processors. + """ + output_token_ids = self.sampling_metadata.output_token_ids + if self.sampled_token_ids_cpu is None or not output_token_ids: + # Output token ids not needed or not async scheduling. + return + + assert self.prev_req_id_to_index is not None + sampled_token_ids = None + for index, req_id in enumerate(self.req_ids): + prev_index = self.prev_req_id_to_index.get(req_id) + if prev_index is None: + continue + req_output_token_ids = output_token_ids[index] + if not req_output_token_ids or req_output_token_ids[-1] != -1: + # Final output id is not a placeholder, some tokens must have + # been discarded after a kv-load failure. 
+ continue + if sampled_token_ids is None: + assert self.async_copy_ready_event is not None + self.async_copy_ready_event.synchronize() + sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist() + # Replace placeholder token id with actual sampled id. + req_output_token_ids[-1] = sampled_token_ids[prev_index] + @property def num_reqs(self) -> int: return len(self.req_id_to_index) @@ -889,7 +955,7 @@ def no_penalties(self) -> bool: ) @property - def max_num_logprobs(self) -> Optional[int]: + def max_num_logprobs(self) -> int | None: return max(self.num_logprobs.values()) if self.num_logprobs else None @property diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b7dc2287b79f..31429fe699a2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,24 +8,23 @@ from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast +from itertools import product +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np import torch import torch.distributed import torch.nn as nn from tqdm import tqdm -from typing_extensions import TypeAlias import vllm.envs as envs from vllm.attention import Attention, AttentionType -from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention +from vllm.attention.backends.abstract import AttentionBackend, MultipleOf from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import ( - CompilationLevel, + CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, @@ -44,10 +43,8 @@ from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader -from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.interfaces import ( SupportsMultiModal, is_mixture_of_experts, @@ -73,18 +70,20 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import ( - STR_DTYPE_TO_TORCH_DTYPE, - DeviceMemoryProfiler, - GiB_bytes, cdiv, check_use_alibi, - get_dtype_size, - is_pin_memory_available, length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo, ) from vllm.utils.jsontree import json_map_leaves +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import ( + get_dtype_size, + kv_cache_dtype_str_to_dtype, + supports_dynamo, +) from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( @@ -106,7 +105,6 @@ KVCacheGroupSpec, KVCacheSpec, MambaSpec, - MLAAttentionSpec, SlidingWindowSpec, UniformTypeKVCacheSpecs, ) @@ -161,7 +159,7 @@ AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list 
when ubatching is enabled -PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict] +PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict # Wrapper for ModelRunnerOutput to support overlapped execution. @@ -177,7 +175,7 @@ def __init__( self._invalid_req_indices = invalid_req_indices # Event on the copy stream so we can synchronize the non-blocking copy. - self._async_copy_ready_event = torch.cuda.Event() + self.async_copy_ready_event = torch.cuda.Event() # Keep a reference to the device tensor to avoid it being # deallocated until we finish copying it to the host. @@ -187,22 +185,22 @@ def __init__( default_stream = torch.cuda.current_stream() with torch.cuda.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) - self._sampled_token_ids_cpu = self._sampled_token_ids.to( + self.sampled_token_ids_cpu = self._sampled_token_ids.to( "cpu", non_blocking=True ) - self._async_copy_ready_event.record() + self.async_copy_ready_event.record() def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. This function blocks until the copy is finished. """ - self._async_copy_ready_event.synchronize() + self.async_copy_ready_event.synchronize() # Release the device tensor once the copy has completed del self._sampled_token_ids - valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() + valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() for i in self._invalid_req_indices: valid_sampled_token_ids[i].clear() @@ -231,9 +229,6 @@ def __init__( from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) - from vllm.model_executor.layers.batch_invariant import init_batch_invariance - - init_batch_invariance() model_config = self.model_config cache_config = self.cache_config @@ -242,10 +237,9 @@ def __init__( self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - self.kv_cache_dtype = self.dtype - else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + self.kv_cache_dtype = kv_cache_dtype_str_to_dtype( + cache_config.cache_dtype, self.model_config + ) self.is_pooling_model = model_config.runner_type == "pooling" self.enable_prompt_embeds = model_config.enable_prompt_embeds @@ -294,7 +288,7 @@ def __init__( # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) - self.eplb_state: Optional[EplbState] = None + self.eplb_state: EplbState | None = None """ State of the expert parallelism load balancer. @@ -333,7 +327,7 @@ def __init__( "Unknown speculative decoding method: " f"{self.speculative_config.method}" ) - self.rejection_sampler = RejectionSampler() + self.rejection_sampler = RejectionSampler(self.sampler) # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -348,6 +342,7 @@ def __init__( # solution, we initialize the input batch here, and re-initialize it # in `initialize_kv_cache` if the block_sizes here is different from # the block_sizes in the kv cache config. 
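# Editorial aside: a small sketch of the side-stream copy pattern used by
# AsyncGPUModelRunnerOutput above: copy the sampled token ids to the CPU on a
# dedicated stream, record an event, and only synchronize when the result is
# actually consumed. Requires a CUDA device; all names here are illustrative.
# (In vLLM the destination buffer is typically pinned so the copy is truly
# asynchronous.)
import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    ready_event = torch.cuda.Event()

    sampled_gpu = torch.randint(0, 32000, (8, 1), device="cuda")

    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(copy_stream):
        # Make the copy wait for any pending work that produced sampled_gpu.
        copy_stream.wait_stream(default_stream)
        sampled_cpu = sampled_gpu.to("cpu", non_blocking=True)
        ready_event.record()

    # ... the default stream is free to launch the next step here ...

    ready_event.synchronize()  # block only once the CPU copy is needed
    token_ids = sampled_cpu.tolist()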
+ custom_logitsprocs = model_config.logits_processors self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, # We need to use the encoder length for encoder-decoer @@ -358,32 +353,39 @@ def __init__( pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], + kernel_block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, self.device, self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors, + custom_logitsprocs, ), + # We currently don't know whether a particular custom logits processor + # uses output token ids so we set this conservatively. + logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, ) self.use_async_scheduling = self.scheduler_config.async_scheduling - self.async_output_copy_stream = ( - torch.cuda.Stream() if self.use_async_scheduling else None - ) + # Separate cuda stream for overlapping transfer of sampled token ids from + # GPU to CPU when async scheduling is enabled. + self.async_output_copy_stream: torch.cuda.Stream | None = None + # cuda event to synchronize use of reused CPU tensors between steps + # when async scheduling is enabled. + self.prepare_inputs_event: torch.cuda.Event | None = None + if self.use_async_scheduling: + self.async_output_copy_stream = torch.cuda.Stream() + self.prepare_inputs_event = torch.cuda.Event() - # TODO(woosuk): Provide an option to tune the max cudagraph batch size. - # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. - # The batch sizes in the config are in descending order. if ( self.compilation_config.cudagraph_capture_sizes and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE ): - self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes) + self.cudagraph_batch_sizes = sorted( + self.compilation_config.cudagraph_capture_sizes ) # Cache the device properties. @@ -396,6 +398,10 @@ def __init__( self.max_num_reqs + 1, dtype=torch.int32 ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + if self.dcp_world_size > 1: + self.dcp_local_seq_lens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. @@ -435,16 +441,8 @@ def __init__( (3, self.max_num_tokens + 1), dtype=torch.int64 ) - # CUDA event to synchronize use of reused CPU tensors between steps - # when async scheduling is enabled. - self.prepare_inputs_event: Optional[torch.cuda.Event] = None - if self.use_async_scheduling: - self.prepare_inputs_event = torch.cuda.Event() - # Start in a completed state. - self.prepare_inputs_event.record(torch.cuda.default_stream()) - # None in the first PP rank. The rest are set after load_model. - self.intermediate_tensors: Optional[IntermediateTensors] = None + self.intermediate_tensors: IntermediateTensors | None = None # OPTIMIZATION: Cache the tensors rather than creating them every step. 
# Keep in int64 to avoid overflow with long context @@ -485,7 +483,7 @@ def __init__( else None ) - self.reorder_batch_threshold: Optional[int] = None + self.reorder_batch_threshold: int | None = None # Attention layers that are only in the KVCacheConfig of the runner # (e.g., KV sharing, encoder-only attention), but not in the @@ -493,7 +491,7 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. - self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None + self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_model_len, 1), @@ -502,6 +500,10 @@ def __init__( pin_memory=self.pin_memory, ) + def reset_mm_cache(self) -> None: + if self.mm_budget: + self.mm_budget.reset_cache() + def _get_positions(self, num_tokens: Any): if isinstance(num_tokens, int): if self.uses_mrope: @@ -513,7 +515,7 @@ def _get_positions(self, num_tokens: Any): return self.positions.gpu[num_tokens] def _make_buffer( - self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True + self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True ) -> CpuGpuBuffer: return CpuGpuBuffer( *size, @@ -579,7 +581,10 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # NOTE(lucas): currently no backend supports the custom masking # required for DCP with q_len > 1, so we assert here. Remove this # assert once the custom mask is support is added to FA3. - if self.dcp_world_size > 1: + if ( + self.dcp_world_size > 1 + and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA" + ): assert self.reorder_batch_threshold == 1, ( "DCP not support reorder_batch_threshold > 1 now." ) @@ -699,6 +704,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the cached states. req_state.num_computed_tokens = num_computed_tokens + req_index = self.input_batch.req_id_to_index.get(req_id) if not is_last_rank: # When using PP, the scheduler sends the sampled tokens back, @@ -719,19 +725,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Some output tokens were discarded due to a sync-KV-load # failure. Align the cached state. del req_state.output_token_ids[num_output_tokens:] - - req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is not None: - old_end_idx = self.input_batch.num_tokens_no_spec[req_index] end_idx = ( self.input_batch.num_prompt_tokens[req_index] + num_output_tokens ) self.input_batch.num_tokens[req_index] = end_idx self.input_batch.num_tokens_no_spec[req_index] = end_idx - self.input_batch.is_token_ids[req_index, end_idx:old_end_idx] = ( - False - ) # Update the block IDs. if not resumed_from_preemption: @@ -740,12 +740,18 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: + assert req_index is None assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids - req_index = self.input_batch.req_id_to_index.get(req_id) + if self.use_async_scheduling and num_output_tokens > 0: + # We must recover the output token ids for resumed requests in the + # async scheduling case, so that correct input_ids are obtained. 
+ resumed_token_ids = req_data.resumed_req_token_ids[i] + assert resumed_token_ids is not None + req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] if req_index is None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not @@ -772,7 +778,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add spec_token_ids to token_ids_cpu. spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, () + req_id, [] ) if spec_token_ids: num_spec_tokens = len(spec_token_ids) @@ -783,7 +789,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: ] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. self.input_batch.num_tokens[req_index] += num_spec_tokens - self.input_batch.spec_token_ids[req_index] = spec_token_ids + + # When speculative decoding is used with structured output, + # the scheduler can drop draft tokens that do not + # conform to the schema. This can result in + # scheduler_output.scheduled_spec_decode_tokens being empty, + # even when speculative decoding is enabled. + self.input_batch.spec_token_ids[req_index] = spec_token_ids # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. @@ -857,30 +869,19 @@ def _init_mrope_positions(self, req_state: CachedRequestState): if mm_input.get("use_audio_in_video") is True: use_audio_in_video = True - if supports_mrope(self.model): - req_state.mrope_positions, req_state.mrope_position_delta = ( - self.model.get_mrope_input_positions( - req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - ) - else: - req_state.mrope_positions, req_state.mrope_position_delta = ( - MRotaryEmbedding.get_input_positions_tensor( - req_state.prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + assert supports_mrope(self.get_model()), "M-RoPE support is not implemented." + + req_state.mrope_positions, req_state.mrope_position_delta = ( + self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, ) + ) def _extract_mm_kwargs( self, @@ -921,7 +922,7 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: def _get_cumsum_and_arange( self, num_tokens: np.ndarray, - cumsum_dtype: Optional[np.dtype] = None, + cumsum_dtype: np.dtype | None = None, ) -> tuple[np.ndarray, np.ndarray]: """Get the cumulative sum and batched arange of the given array. # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) @@ -983,7 +984,7 @@ def _prepare_input_ids( self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) if num_commmon_tokens == 0: # No requests in common with the previous iteration - # So input_ids_cpu will have all the input ids. + # So input_ids.cpu will have all the input ids. 
return if indices_match and max_flattened_index == (num_commmon_tokens - 1): # Common-case optimization: the batch is unchanged @@ -997,8 +998,7 @@ def _prepare_input_ids( if self.enable_prompt_embeds: self.is_token_ids.gpu[:num_commmon_tokens] = True return - # Upload the index tensors asynchronously - # so the scatter can be non-blocking. + # Upload the index tensors asynchronously so the scatter can be non-blocking. input_ids_index_tensor = torch.tensor( flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) @@ -1018,7 +1018,7 @@ def _get_encoder_seq_lens( scheduler_output: "SchedulerOutput", kv_cache_spec: KVCacheSpec, num_reqs: int, - ) -> Optional[np.ndarray]: + ) -> np.ndarray | None: if not isinstance(kv_cache_spec, CrossAttentionSpec): return None @@ -1036,12 +1036,12 @@ def _prepare_inputs( ) -> tuple[ PerLayerAttnMetadata, torch.Tensor, - Optional[SpecDecodeMetadata], + SpecDecodeMetadata | None, np.ndarray, - Optional[CommonAttentionMetadata], + CommonAttentionMetadata | None, int, - Optional[UBatchSlices], - Optional[torch.Tensor], + UBatchSlices | None, + torch.Tensor | None, bool, ]: """ @@ -1170,13 +1170,21 @@ def _prepare_inputs( uniform_decode = ( max_num_scheduled_tokens == self.uniform_decode_query_len ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + + # Disable DP padding when running eager to avoid excessive padding when + # running prefills. This lets us set enforce_eager on the prefiller in + # a P/D setup and still use CUDA graphs (enabled by this padding) on the + # decoder. + allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( - num_scheduled_tokens, - num_tokens_unpadded, - num_tokens_padded, - self.parallel_config, - True, - uniform_decode, + num_tokens_unpadded=num_tokens_unpadded, + parallel_config=self.parallel_config, + allow_microbatching=True, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=num_tokens_padded, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens, ) self.seq_lens.np[:num_reqs] = ( @@ -1333,6 +1341,9 @@ def _prepare_inputs( num_logits_indices=logits_indices.size(0), causal=True, encoder_seq_lens=encoder_seq_lens, + dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] + if self.dcp_world_size > 1 + else None, ) if self.speculative_config and spec_decode_common_attn_metadata is None: @@ -1507,6 +1518,7 @@ def _compute_cascade_attn_prefix_len( use_sliding_window=use_sliding_window, use_local_attention=use_local_attention, num_sms=self.num_sms, + dcp_world_size=self.dcp_world_size, ) return common_prefix_len if use_cascade else 0 @@ -1609,6 +1621,9 @@ def _calc_spec_decode_metadata( cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( self.device, non_blocking=True ) + cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to( + self.device, non_blocking=True + ) logits_indices = torch.from_numpy(logits_indices).to( self.device, non_blocking=True ) @@ -1624,15 +1639,15 @@ def _calc_spec_decode_metadata( draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] - metadata = SpecDecodeMetadata( + return SpecDecodeMetadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, + cu_num_sampled_tokens=cu_num_sampled_tokens, target_logits_indices=target_logits_indices, 
bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) - return metadata def _prepare_kv_sharing_fast_prefill( self, @@ -1720,20 +1735,32 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): pin_memory=self.pin_memory, merge_by_field_config=model.merge_by_field_config, ): + curr_group_outputs = [] + + # EVS-related change. # (ekhvedchenia): Temporary hack to limit peak memory usage when - # processing multimodal data.This solves the issue with scheduler + # processing multimodal data. This solves the issue with scheduler # putting too many video samples into a single batch. Scheduler # uses pruned vision tokens count to compare it versus compute # budget which is incorrect (Either input media size or non-pruned # output vision tokens count should be considered) - curr_group_outputs = [] - - if self.is_multimodal_pruning_enabled and modality == "video": - micro_batch_size = 1 - for i in range(0, num_items, micro_batch_size): - micro_batch_mm_inputs = dict( - (k, v[i : i + micro_batch_size]) - for k, v in mm_kwargs_group.items() + # TODO(ywang96): Fix memory profiling to take EVS into account and + # remove this hack. + if ( + self.is_multimodal_pruning_enabled + and modality == "video" + and num_items > 1 + ): + for video_mm_kwargs_item in filter( + lambda item: item.modality == "video", mm_kwargs + ): + _, _, micro_batch_mm_inputs = next( + group_mm_kwargs_by_modality( + [video_mm_kwargs_item], + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) ) micro_batch_outputs = model.get_multimodal_embeddings( @@ -1911,15 +1938,16 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: supported_tasks = list(model.pooler.get_supported_tasks()) - if ( - self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks - ): - supported_tasks.remove("encode") + if self.scheduler_config.chunked_prefill_enabled: + if "token_embed" in supported_tasks: + supported_tasks.remove("token_embed") + if "token_classify" in supported_tasks: + supported_tasks.remove("token_classify") logger.debug_once( "Chunked prefill is not supported with " - "encode task which using ALL pooling. " + "token_embed and token_classify tasks " + "which using ALL pooling. " "Please turn off chunked prefill by " "`--no-enable-chunked-prefill` before using it." ) @@ -1994,7 +2022,8 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: # Should be called after attention metadata creation. 
This just pads # the second ubatch slice out to the total number of tokens # (num_tokens + padding) - def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int): + @staticmethod + def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int): padded_second_ubatch_slice = slice( ubatch_slices[1].token_slice.start, num_total_tokens ) @@ -2030,7 +2059,7 @@ def _pool( ) self._sync_device() - pooler_output: list[Optional[torch.Tensor]] = [] + pooler_output: list[torch.Tensor | None] = [] for raw_output, seq_len, prompt_len in zip( raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens ): @@ -2049,7 +2078,6 @@ def _pool( def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: if ( self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH and hasattr(self, "cudagraph_batch_sizes") and self.cudagraph_batch_sizes and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] @@ -2073,22 +2101,23 @@ def _preprocess( self, scheduler_output: "SchedulerOutput", num_input_tokens: int, # Padded - intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors: IntermediateTensors | None = None, ) -> tuple[ int, - Optional[torch.Tensor], - Optional[torch.Tensor], + torch.Tensor | None, + torch.Tensor | None, torch.Tensor, - Optional[IntermediateTensors], + IntermediateTensors | None, dict[str, Any], ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + is_first_rank = get_pp_group().is_first_rank # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order if ( self.supports_mm_inputs - and get_pp_group().is_first_rank + and is_first_rank and not self.model_config.is_encoder_decoder ): # Run the multimodal encoder if any. @@ -2113,7 +2142,7 @@ def _preprocess( **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), } - elif self.enable_prompt_embeds and get_pp_group().is_first_rank: + elif self.enable_prompt_embeds and is_first_rank: # Get the input embeddings for the tokens that are not input embeds, # then put them into the appropriate positions. # TODO(qthequartermasterman): Since even when prompt embeds are @@ -2153,7 +2182,7 @@ def _preprocess( else: positions = self.positions.gpu[:num_input_tokens] - if get_pp_group().is_first_rank: + if is_first_rank: intermediate_tensors = None else: intermediate_tensors = self.sync_and_slice_intermediate_tensors( @@ -2178,58 +2207,42 @@ def _preprocess( def _sample( self, - logits: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata], + logits: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: - sampler_output = self.sampler( + # Update output token ids with tokens sampled in last step + # if async scheduling and required by current sampling params. + self.input_batch.update_async_output_token_ids() + return self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) - else: - # When indexing with a tensor (bonus_logits_indices), PyTorch - # creates a new tensor with separate storage from the original - # logits tensor. This means any in-place operations on bonus_logits - # won't affect the original logits tensor. 
- assert logits is not None - bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] - sampler_output = self.sampler( - logits=bonus_logits, - sampling_metadata=sampling_metadata, - predict_bonus_token=True, - ) - bonus_token_ids = sampler_output.sampled_token_ids - - # Just like `bonus_logits`, `target_logits` is a new tensor with - # separate storage from the original `logits` tensor. Therefore, - # it is safe to update `target_logits` in place. - target_logits = logits[spec_decode_metadata.target_logits_indices] - output_token_ids = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - target_logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - self._update_states_after_model_execute(output_token_ids) + sampler_output = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + logits, + sampling_metadata, + ) + self._update_states_after_model_execute(sampler_output.sampled_token_ids) return sampler_output def _bookkeeping_sync( self, scheduler_output: "SchedulerOutput", sampler_output: SamplerOutput, - logits: Optional[torch.Tensor], + logits: torch.Tensor | None, hidden_states: torch.Tensor, num_scheduled_tokens: int, + spec_decode_metadata: SpecDecodeMetadata | None, ) -> tuple[ dict[str, int], - Optional[LogprobsLists], + LogprobsLists | None, list[list[int]], - dict[str, Optional[LogprobsTensors]], + dict[str, LogprobsTensors | None], list[str], dict[str, int], list[int], @@ -2251,19 +2264,6 @@ def _bookkeeping_sync( req_ids_output_copy = self.input_batch.req_ids.copy() req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() - # NOTE: GPU -> CPU Sync happens here. - # Move as many CPU operations as possible before this sync point. - logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = ( - logprobs_tensors.tolists() if logprobs_tensors is not None else None - ) - - # Compute prompt logprobs if needed. - prompt_logprobs_dict = self._get_prompt_logprobs_dict( - hidden_states[:num_scheduled_tokens], - scheduler_output.num_scheduled_tokens, - ) - num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids invalid_req_indices = [] @@ -2292,9 +2292,6 @@ def _bookkeeping_sync( # These will be copied into input_ids in the next step # when preparing inputs. self.input_batch.prev_sampled_token_ids = sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = ( - invalid_req_indices_set - ) self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) @@ -2307,6 +2304,10 @@ def _bookkeeping_sync( # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. req_ids = self.input_batch.req_ids + logprobs_tensors = sampler_output.logprobs_tensors + cu_num_accepted_tokens = ( + [0] if spec_decode_metadata and logprobs_tensors else None + ) for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None @@ -2332,6 +2333,25 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) + if cu_num_accepted_tokens is not None: + cu_num_accepted_tokens.append( + cu_num_accepted_tokens[-1] + len(sampled_ids) + ) + + # NOTE: GPU -> CPU Sync happens here. + # Move as many CPU operations as possible before this sync point. 
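# Editorial aside: a sketch of the cumulative-offset bookkeeping added in
# _bookkeeping_sync for spec decode + logprobs. With rejection sampling each
# request may accept a different number of tokens, so cu_num_accepted_tokens
# records prefix sums that can be used to slice the per-token logprob rows
# belonging to each request. Values are purely illustrative.
accepted_per_request = [3, 1, 2]   # tokens accepted per request this step

cu_num_accepted_tokens = [0]
for n in accepted_per_request:
    cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + n)

# cu_num_accepted_tokens == [0, 3, 4, 6]; request i owns rows
# [cu_num_accepted_tokens[i] : cu_num_accepted_tokens[i + 1]].
assert cu_num_accepted_tokens == [0, 3, 4, 6]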
+ logprobs_lists = ( + logprobs_tensors.tolists(cu_num_accepted_tokens) + if logprobs_tensors is not None + else None + ) + + # Compute prompt logprobs if needed. + prompt_logprobs_dict = self._get_prompt_logprobs_dict( + hidden_states[:num_scheduled_tokens], + scheduler_output.num_scheduled_tokens, + ) + return ( num_nans_in_logits, logprobs_lists, @@ -2359,10 +2379,10 @@ def synchronize_input_prep(self): def _model_forward( self, - input_ids: Optional[torch.Tensor] = None, - positions: Optional[torch.Tensor] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, **model_kwargs: dict[str, Any], ) -> Any: """Helper method to call the model forward pass. @@ -2393,8 +2413,8 @@ def _model_forward( def execute_model( self, scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + intermediate_tensors: IntermediateTensors | None = None, + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: with record_function_or_nullcontext("Preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. @@ -2427,12 +2447,13 @@ def execute_model( use_cascade_attn, ) = self._prepare_inputs(scheduler_output) + dp_rank = self.parallel_config.data_parallel_rank if ubatch_slices: assert num_tokens_across_dp is not None - num_input_tokens = int(num_tokens_across_dp[0].item()) + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif num_tokens_across_dp is not None: - num_input_tokens = int(num_tokens_across_dp[0].item()) + num_input_tokens = int(num_tokens_across_dp[dp_rank].item()) else: num_input_tokens = self._get_num_input_tokens( scheduler_output.total_num_scheduled_tokens @@ -2453,7 +2474,9 @@ def execute_model( num_scheduled_tokens == self.input_batch.num_reqs * max_query_len ) batch_descriptor = BatchDescriptor( - num_tokens=num_input_tokens, uniform_decode=uniform_decode + num_tokens=num_input_tokens, + uniform_decode=uniform_decode, + has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, ) cudagraph_runtime_mode, batch_descriptor = ( self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) @@ -2552,10 +2575,8 @@ def execute_model( logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present - if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask( - scheduler_output, self.input_batch, logits, self.device - ) + if scheduler_output.structured_output_request_ids: + apply_grammar_bitmask(scheduler_output, self.input_batch, logits) with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) @@ -2615,6 +2636,7 @@ def propose_draft_token_ids(sampled_token_ids): logits, hidden_states, num_scheduled_tokens, + spec_decode_metadata, ) if ( @@ -2643,14 +2665,23 @@ def propose_draft_token_ids(sampled_token_ids): if not self.use_async_scheduling: return output - return AsyncGPUModelRunnerOutput( + async_output = AsyncGPUModelRunnerOutput( model_runner_output=output, sampled_token_ids=sampler_output.sampled_token_ids, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, ) - def 
take_draft_token_ids(self) -> Optional[DraftTokenIds]: + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) + + return async_output + + def take_draft_token_ids(self) -> DraftTokenIds | None: if self._draft_token_ids is None: return None req_ids = self.input_batch.req_ids @@ -2664,14 +2695,14 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: Union[torch.Tensor, list[list[int]]], + sampled_token_ids: torch.Tensor | list[list[int]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, - aux_hidden_states: Optional[list[torch.Tensor]], - spec_decode_metadata: Optional[SpecDecodeMetadata], + aux_hidden_states: list[torch.Tensor] | None, + spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, - ) -> Union[list[list[int]], torch.Tensor]: + ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(sampled_token_ids, list) @@ -2819,7 +2850,11 @@ def load_model(self, eep_scale_up: bool = False) -> None: Args: eep_scale_up: the model loading is for elastic EP scale up. """ - logger.info("Starting to load model %s...", self.model_config.model) + logger.info_once( + "Starting to load model %s...", + self.model_config.model, + scope="global", + ) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group @@ -2849,7 +2884,6 @@ def load_model(self, eep_scale_up: bool = False) -> None: with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() model_loader = get_model_loader(self.load_config) - logger.info("Loading model from scratch...") self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config ) @@ -2861,7 +2895,7 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - if not supports_eagle3(self.model): + if not supports_eagle3(self.get_model()): raise RuntimeError( "Model does not support EAGLE3 interface but " "aux_hidden_state_outputs was requested" @@ -2881,15 +2915,16 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info( + logger.info_once( "Model loading took %.4f GiB and %.6f seconds", self.model_memory_usage / GiB_bytes, time_after_load - time_before_load, + scope="local", ) prepare_communication_buffer_for_model(self.model) self.is_multimodal_pruning_enabled = ( - supports_multimodal_pruning(self.model) + supports_multimodal_pruning(self.get_model()) and self.model_config.multimodal_config.is_multimodal_pruning_enabled() ) @@ -2905,14 +2940,15 @@ def load_model(self, eep_scale_up: bool = False) -> None: ) if ( - self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS + self.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE and supports_dynamo() ): backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) - compilation_counter.dynamo_as_is_count += 1 + 
compilation_counter.stock_torch_compile_count += 1 self.model.compile(fullgraph=True, backend=backend) return - # for other compilation levels, cudagraph behavior is controlled by + # for other compilation modes, cudagraph behavior is controlled by # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. @@ -2933,7 +2969,7 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model, self.vllm_config, CUDAGraphMode.NONE, self.device ) - def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]: + def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: """Extract Eagle3 auxiliary layer indices from speculative config. These indices specify which hidden states from the base model should @@ -2979,13 +3015,13 @@ def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, num_scheduled_tokens: dict[str, int], - ) -> dict[str, Optional[LogprobsTensors]]: + ) -> dict[str, LogprobsTensors | None]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: return {} in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu - prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} # Since prompt logprobs are a rare feature, prioritize simple, # maintainable loop over optimal performance. @@ -3079,7 +3115,7 @@ def _get_prompt_logprobs_dict( def _get_nans_in_logits( self, - logits: Optional[torch.Tensor], + logits: torch.Tensor | None, ) -> dict[str, int]: try: if logits is None: @@ -3162,7 +3198,7 @@ def _get_mm_dummy_batch( def _dummy_run( self, num_tokens: int, - cudagraph_runtime_mode: Optional[CUDAGraphMode] = None, + cudagraph_runtime_mode: CUDAGraphMode | None = None, force_attention: bool = False, uniform_decode: bool = False, allow_microbatching: bool = True, @@ -3170,6 +3206,7 @@ def _dummy_run( is_profile: bool = False, create_mixed_batch: bool = False, remove_lora: bool = True, + activate_lora: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -3192,6 +3229,7 @@ def _dummy_run( create_mixed_batch: If True, create a mixed batch with both decode (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run + activate_lora: If False, dummy_run is performed without LoRAs. """ assert ( cudagraph_runtime_mode is None @@ -3247,21 +3285,26 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) + # Disable DP padding when running eager + allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + # We currently only microbatch if the number of tokens is # over a certain threshold. 
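# Illustrative sketch (not part of the patch): with the change above, each DP
# rank reads its own padded token count from the per-rank tensor returned by DP
# coordination, instead of always using rank 0's entry. The tensor values and
# rank below are invented.
import torch

num_tokens_across_dp = torch.tensor([96, 128, 64, 128])  # hypothetical padded sizes
dp_rank = 2
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
assert num_input_tokens == 64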
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( - num_scheduled_tokens, - total_num_scheduled_tokens, - total_num_scheduled_tokens, - self.vllm_config.parallel_config, - allow_microbatching, - uniform_decode, + num_tokens_unpadded=total_num_scheduled_tokens, + parallel_config=self.vllm_config.parallel_config, + allow_microbatching=allow_microbatching, + allow_dp_padding=allow_dp_padding, + num_tokens_padded=total_num_scheduled_tokens, + uniform_decode=uniform_decode, + num_scheduled_tokens_per_request=num_scheduled_tokens, ) num_tokens_after_padding = num_tokens if num_tokens_across_dp is not None: - num_tokens_after_padding = int(num_tokens_across_dp[0]) + dp_rank = self.parallel_config.data_parallel_rank + num_tokens_after_padding = int(num_tokens_across_dp[dp_rank]) - attn_metadata: Optional[PerLayerAttnMetadata] = None + attn_metadata: PerLayerAttnMetadata | None = None # If force_attention is True, we always capture attention. Otherwise, # it only happens for cudagraph_runtime_mode=FULL. @@ -3307,6 +3350,9 @@ def _dummy_run( kv_cache_group_id ].slot_mapping.gpu[:num_tokens], causal=True, + dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] + if self.dcp_world_size > 1 + else None, ) for attn_group in self.attn_groups[kv_cache_group_id]: if ubatch_slices is not None: @@ -3333,7 +3379,7 @@ def _dummy_run( attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora( - self.lora_config, num_scheduled_tokens, remove_lora + self.lora_config, num_scheduled_tokens, activate_lora, remove_lora ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_after_padding <= self.max_num_tokens @@ -3380,6 +3426,7 @@ def _dummy_run( BatchDescriptor( num_tokens=num_tokens_after_padding, uniform_decode=uniform_decode, + has_lora=activate_lora and self.lora_config is not None, ) ) if not is_profile @@ -3433,7 +3480,11 @@ def _dummy_run( if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - self.drafter.dummy_run(num_tokens) + use_cudagraphs = ( + cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + and not self.speculative_config.enforce_eager + ) + self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs) # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real @@ -3446,7 +3497,10 @@ def _dummy_run( self.eplb_step(is_dummy=True, is_profile=is_profile) logit_indices = np.cumsum(num_scheduled_tokens) - 1 - return hidden_states, hidden_states[logit_indices] + logit_indices_device = torch.from_numpy(logit_indices).to( + self.device, non_blocking=True + ) + return hidden_states, hidden_states[logit_indices_device] @torch.inference_mode() def _dummy_sampler_run( @@ -3507,20 +3561,16 @@ def _dummy_sampler_run( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn( - num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype - ) - # NOTE(woosuk): Here, we should use int32 because the sampler uses - # int32 for bonus_token_ids. If the dtype mismatches, re-compilation - # will occur at runtime. 
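# Illustrative sketch of the logit-index change in _dummy_run above: cumsum - 1
# over the per-request scheduled token counts gives the flat index of each
# request's last token, and the index tensor is moved to the device before
# indexing so the gather happens there. Shapes here are invented.
import numpy as np
import torch

num_scheduled_tokens = np.array([4, 1, 3], dtype=np.int32)    # three dummy requests
hidden_states = torch.randn(int(num_scheduled_tokens.sum()), 8)
logit_indices = np.cumsum(num_scheduled_tokens) - 1           # -> [3, 4, 7]
idx = torch.from_numpy(logit_indices).to(hidden_states.device, non_blocking=True)
last_token_states = hidden_states[idx]                        # shape (3, 8)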
- bonus_token_ids = torch.zeros( - num_reqs, device=self.device, dtype=torch.int32 + logits = torch.randn( + num_tokens + num_reqs, + logits.shape[-1], + device=self.device, + dtype=logits.dtype, ) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, - target_logits, - bonus_token_ids, + logits, dummy_metadata, ) return sampler_output @@ -3586,8 +3636,28 @@ def _dummy_pooler_run( hidden_states: torch.Tensor, ) -> PoolerOutput: # Find the task that has the largest output for subsequent steps + supported_pooling_tasks = self.get_supported_pooling_tasks() + + if not supported_pooling_tasks: + if self.scheduler_config.chunked_prefill_enabled: + raise RuntimeError( + f"Model {self.model_config.model} does not support " + "any pooling tasks with chunked prefill enabled. " + "Please add --no-enable-chunked-prefill to your " + "config or CLI args. See " + "https://docs.vllm.ai/en/latest/models/pooling_models.html " + "to learn more." + ) + else: + raise RuntimeError( + f"Model {self.model_config.model} does not support " + "any pooling tasks. See " + "https://docs.vllm.ai/en/latest/models/pooling_models.html " + "to learn more." + ) + output_size = dict[PoolingTask, float]() - for task in self.get_supported_pooling_tasks(): + for task in supported_pooling_tasks: # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) output_size[task] = sum(o.nbytes for o in output) @@ -3686,8 +3756,6 @@ def capture_model(self) -> int: "ensure `cudagraph_mode` was not manually set to `NONE`" ) return 0 - else: - self.initialize_cudagraph_capture() compilation_counter.num_gpu_runner_capture_triggers += 1 @@ -3717,10 +3785,21 @@ def freeze_gc(): start_free_gpu_memory = torch.cuda.mem_get_info()[0] cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None + + if self.lora_config: + if self.compilation_config.cudagraph_specialize_lora: + lora_cases = [True, False] + else: + lora_cases = [True] + else: + lora_cases = [False] + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - - compilation_cases = list(reversed(self.cudagraph_batch_sizes)) + # make sure we capture the largest batch size first + compilation_cases = list( + product(reversed(self.cudagraph_batch_sizes), lora_cases) + ) self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, @@ -3739,9 +3818,11 @@ def freeze_gc(): decode_cudagraph_batch_sizes = [ x for x in self.cudagraph_batch_sizes - if x <= max_num_tokens and x >= self.uniform_decode_query_len + if max_num_tokens >= x >= self.uniform_decode_query_len ] - compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes)) + compilation_cases_decode = list( + product(reversed(decode_cudagraph_batch_sizes), lora_cases) + ) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, @@ -3762,16 +3843,17 @@ def freeze_gc(): elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. 
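# Illustrative sketch of the capture-case change above: cases are now
# (num_tokens, activate_lora) pairs. With LoRA enabled and
# cudagraph_specialize_lora set, every batch size is captured twice, largest
# batch first. The batch sizes below are invented.
from itertools import product

cudagraph_batch_sizes = [1, 2, 4, 8]
lora_cases = [True, False]           # cudagraph_specialize_lora=True with LoRA
compilation_cases = list(product(reversed(cudagraph_batch_sizes), lora_cases))
# -> [(8, True), (8, False), (4, True), (4, False), (2, True), (2, False),
#     (1, True), (1, False)]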
- logger.info( + logger.info_once( "Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30), + scope="local", ) return cuda_graph_size def _capture_cudagraphs( self, - compilation_cases: list[int], + compilation_cases: list[tuple[int, bool]], cudagraph_runtime_mode: CUDAGraphMode, uniform_decode: bool, ): @@ -3792,7 +3874,7 @@ def _capture_cudagraphs( ) # We skip EPLB here since we don't want to record dummy metrics - for num_tokens in compilation_cases: + for num_tokens, activate_lora in compilation_cases: # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched @@ -3823,6 +3905,7 @@ def _capture_cudagraphs( allow_microbatching=allow_microbatching, skip_eplb=True, remove_lora=False, + activate_lora=activate_lora, ) self._dummy_run( num_tokens, @@ -3831,6 +3914,7 @@ def _capture_cudagraphs( allow_microbatching=allow_microbatching, skip_eplb=True, remove_lora=False, + activate_lora=activate_lora, ) self.maybe_remove_all_loras(self.lora_config) @@ -3846,7 +3930,7 @@ class AttentionGroupKey(NamedTuple): def get_attn_backends_for_group( kv_cache_group_spec: KVCacheGroupSpec, - ) -> dict[AttentionGroupKey, list[str]]: + ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]: layers = get_layers_from_vllm_config( self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names ) @@ -3875,7 +3959,10 @@ def get_attn_backends_for_group( attn_backend, layer_kv_cache_spec ) attn_backend_layers[key].append(layer_name) - return {attn_backends[k]: v for k, v in attn_backend_layers.items()} + return ( + {attn_backends[k]: v for k, v in attn_backend_layers.items()}, + set(group_key.attn_backend for group_key in attn_backends.values()), + ) def create_attn_groups( attn_backends_map: dict[AttentionGroupKey, list[str]], @@ -3896,14 +3983,25 @@ def create_attn_groups( attn_groups.append(attn_group) return attn_groups + attention_backend_maps = [] + attention_backend_set: set[type[AttentionBackend]] = set() for kv_cache_group_spec in kv_cache_config.kv_cache_groups: attn_backends = get_attn_backends_for_group(kv_cache_group_spec) - self.attn_groups.append(create_attn_groups(attn_backends)) + attention_backend_maps.append(attn_backends[0]) + attention_backend_set.update(attn_backends[1]) + + # Resolve cudagraph_mode before actually initialize metadata_builders + self._check_and_update_cudagraph_mode(attention_backend_set) + + for attn_backends_map in attention_backend_maps: + self.attn_groups.append(create_attn_groups(attn_backends_map)) # Calculate reorder batch threshold (if needed) self.calculate_reorder_batch_threshold() - def initialize_cudagraph_capture(self) -> None: + def _check_and_update_cudagraph_mode( + self, attention_backends: set[type[AttentionBackend]] + ) -> None: """ Resolve the cudagraph_mode when there are multiple attention backends with potential conflicting CUDA graph support. @@ -3911,13 +4009,13 @@ def initialize_cudagraph_capture(self) -> None: cudagraph_mode. 
""" min_cg_support = AttentionCGSupport.ALWAYS - min_cg_builder_name = None + min_cg_backend_name = None - for attn_group in self._attn_group_iterator(): - builder = attn_group.get_metadata_builder() - if builder.cudagraph_support.value < min_cg_support.value: - min_cg_support = builder.cudagraph_support - min_cg_builder_name = builder.__class__.__name__ + for attn_backend in attention_backends: + builder_cls = attn_backend.get_builder_cls() + if builder_cls.cudagraph_support.value < min_cg_support.value: + min_cg_support = builder_cls.cudagraph_support + min_cg_backend_name = attn_backend.__name__ # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported @@ -3927,14 +4025,14 @@ def initialize_cudagraph_capture(self) -> None: ): msg = ( f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " + f"with {min_cg_backend_name} backend (support: " f"{min_cg_support})" ) if min_cg_support == AttentionCGSupport.NEVER: # if not supported any full cudagraphs, just raise it. msg += ( "; please try cudagraph_mode=PIECEWISE, and " - "make sure compilation level is piecewise" + "make sure compilation mode is VLLM_COMPILE" ) raise ValueError(msg) @@ -3958,10 +4056,10 @@ def initialize_cudagraph_capture(self) -> None: ): msg = ( f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " + f"with {min_cg_backend_name} backend (support: " f"{min_cg_support})" ) - if self.compilation_config.level == CompilationLevel.PIECEWISE and ( + if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and ( self.compilation_config.splitting_ops_contain_attention() or self.compilation_config.use_inductor_graph_partition ): @@ -3992,7 +4090,7 @@ def initialize_cudagraph_capture(self) -> None: msg = ( f"CUDAGraphMode.{cudagraph_mode.name} is not supported" f" with spec-decode for attention backend " - f"{min_cg_builder_name} (support: {min_cg_support})" + f"{min_cg_backend_name} (support: {min_cg_support})" ) if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" @@ -4014,14 +4112,14 @@ def initialize_cudagraph_capture(self) -> None: ): raise ValueError( f"CUDAGraphMode.{cudagraph_mode.name} is not " - f"supported with {min_cg_builder_name} backend (" + f"supported with {min_cg_backend_name} backend (" f"support:{min_cg_support}) " "; please try cudagraph_mode=PIECEWISE, " - "and make sure compilation level is piecewise" + "and make sure compilation mode is VLLM_COMPILE" ) - # Trigger cudagraph dispatching keys initialization here (after - # initializing attn backends). + # Trigger cudagraph dispatching keys initialization after + # resolved cudagraph mode. self.cudagraph_dispatcher.initialize_cudagraph_keys( self.compilation_config.cudagraph_mode, self.uniform_decode_query_len ) @@ -4049,6 +4147,86 @@ def calculate_reorder_batch_threshold(self) -> None: else: self.reorder_batch_threshold = reorder_batch_threshold_i + def _find_compatible_block_sizes( + self, + kv_manager_block_size: int, + backend_cls: type[AttentionBackend], + return_all: bool = False, + ) -> list[int]: + """ + Find compatible block sizes for a backend. 
+ + Args: + kv_manager_block_size: Physical block size of KV cache + backend_cls: Attention backend class + return_all: Return all compatible sizes if True, max size if False + + Returns: + Compatible block size(s) based on return_all parameter + + Raises: + ValueError: If no compatible block size found + """ + supported_block_size = backend_cls.get_supported_kernel_block_size() + compatible_sizes = [] + + for block_size in supported_block_size: + if isinstance(block_size, int): + if kv_manager_block_size % block_size == 0: + compatible_sizes.append(block_size) + elif ( + isinstance(block_size, MultipleOf) + and kv_manager_block_size % block_size.base == 0 + ): + compatible_sizes.append(kv_manager_block_size) + + if not compatible_sizes: + raise ValueError(f"No compatible block size for {kv_manager_block_size}") + + return compatible_sizes if return_all else [max(compatible_sizes)] + + def _select_common_block_size( + self, kv_manager_block_size: int, attn_groups: list[AttentionGroup] + ) -> int: + """ + Select common block size for all backends. + + Args: + kv_manager_block_size: Block size of KV cache + attn_groups: List of attention groups + + Returns: + Block size supported by all backends, + prioritizing cache_config.block_size + + Raises: + ValueError: If no common block size found + """ + all_backend_supports = [] + + for attn_group in attn_groups: + compatible_sizes = self._find_compatible_block_sizes( + kv_manager_block_size, attn_group.backend, return_all=True + ) + supported_sizes = sorted(list(set(compatible_sizes)), reverse=True) + all_backend_supports.append(set(supported_sizes)) + + common_supported_sizes = set.intersection(*all_backend_supports) + + if not common_supported_sizes: + error_msg = f"No common block size for {kv_manager_block_size}. " + for i, attn_group in enumerate(attn_groups): + supported = all_backend_supports[i] + error_msg += ( + f"Backend {attn_group.backend} supports: {sorted(supported)}. " + ) + raise ValueError(error_msg) + + if self.cache_config.block_size in common_supported_sizes: + return self.cache_config.block_size + + return max(common_supported_sizes) + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ Re-initialize the input batch if the block sizes are different from @@ -4061,8 +4239,15 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: block_sizes = [ kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups + if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) ] - if block_sizes != [self.cache_config.block_size]: + + # Generate kernel_block_sizes that matches each block_size + kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) + + if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ + self.cache_config.block_size + ]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. 
See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 @@ -4076,8 +4261,10 @@ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, + logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens @@ -4127,6 +4314,51 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: for attn_groups in self.attn_groups: yield from attn_groups + def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]: + """ + Generate kernel_block_sizes that matches each block_size. + + For attention backends that support virtual block splitting, + use the supported block sizes from the backend. + For other backends (like Mamba), use the same block size (no splitting). + + Args: + kv_cache_config: The KV cache configuration. + + Returns: + list[int]: List of kernel block sizes for each cache group. + """ + kernel_block_sizes = [] + for kv_cache_group_id, kv_cache_group in enumerate( + kv_cache_config.kv_cache_groups + ): + kv_cache_spec = kv_cache_group.kv_cache_spec + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): + # All layers in the UniformTypeKVCacheSpecs have the same type, + # Pick an arbitrary one to dispatch. + kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values())) + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): + continue + elif isinstance(kv_cache_spec, AttentionSpec): + # This is an attention backend that supports virtual + # block splitting. Get the supported block sizes from + # all backends in the group. + attn_groups = self.attn_groups[kv_cache_group_id] + kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size + selected_kernel_size = self._select_common_block_size( + kv_manager_block_size, attn_groups + ) + kernel_block_sizes.append(selected_kernel_size) + elif isinstance(kv_cache_spec, MambaSpec): + # This is likely Mamba or other non-attention cache, + # no splitting. 
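# Illustrative sketch of the selection rule used by _select_common_block_size
# above: each backend contributes the kernel block sizes that evenly divide the
# KV-manager block size, the sets are intersected, and cache_config.block_size
# is preferred when still compatible, otherwise the largest common size wins.
# The numbers below are invented.
def select_common_block_size(kv_manager_block_size, per_backend_supported, preferred):
    common = set.intersection(*(set(sizes) for sizes in per_backend_supported))
    if not common:
        raise ValueError(f"No common block size for {kv_manager_block_size}")
    return preferred if preferred in common else max(common)

# e.g. a 128-token manager block shared by two hypothetical backends
print(select_common_block_size(128, [[16, 32, 64, 128], [32, 64]], preferred=16))  # 64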
+ kernel_block_sizes.append(kv_cache_spec.block_size) + else: + raise NotImplementedError( + f"unknown kv cache spec {kv_cache_group.kv_cache_spec}" + ) + return kernel_block_sizes + def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, @@ -4156,16 +4388,24 @@ def _reshape_kv_cache_tensors( num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True + kv_manager_block_size = kv_cache_spec.block_size + kernel_size_list = self._find_compatible_block_sizes( + kv_manager_block_size, attn_backend, return_all=False + ) + kernel_size = kernel_size_list[0] + num_blocks_per_kv_block = kv_manager_block_size // kernel_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, + kernel_num_blocks, + kernel_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, cache_dtype_str=self.cache_config.cache_dtype, ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501 assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): kv_cache_stride_order = tuple(range(len(kv_cache_shape))) @@ -4319,10 +4559,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.may_reinitialize_input_batch(kv_cache_config) self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) + # Reinitialize need to after initialize_attn_backend + self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if self.speculative_config and self.speculative_config.use_eagle(): @@ -4384,13 +4625,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - cache_dtype_str = self.vllm_config.cache_config.cache_dtype kv_cache_spec: dict[str, KVCacheSpec] = {} - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: + if isinstance(attn_module, Attention) and ( + kv_tgt_layer := attn_module.kv_sharing_target_layer_name + ): # The layer doesn't need its own KV cache and will use that of # the target layer. We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -4400,91 +4640,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # or enable more requests to be processed simultaneously. 
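# Illustrative sketch of the reshape arithmetic in _reshape_kv_cache_tensors
# above: when the kernel's block size is smaller than the KV-manager block
# size, each physical block is viewed as several kernel blocks, so the cache is
# reshaped with more, smaller blocks while its total capacity is unchanged.
# Sizes below are invented.
num_blocks = 1000                # blocks allocated by the KV cache manager
kv_manager_block_size = 128      # tokens per manager block
kernel_size = 32                 # largest block size the attention kernel accepts
num_blocks_per_kv_block = kv_manager_block_size // kernel_size   # 4
kernel_num_blocks = num_blocks * num_blocks_per_kv_block          # 4000
assert kernel_num_blocks * kernel_size == num_blocks * kv_manager_block_size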
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue - - # TODO(lucas): move the attention specs into the model layers like - # the attention backends - if attn_module.attn_type == AttentionType.DECODER: - if attn_module.sliding_window is not None: - assert not use_mla, "MLA is not supported for slidingwindow" - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - ) - elif use_mla: - kv_cache_spec[layer_name] = MLAAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str, - ) - elif self.attention_chunk_size is not None and isinstance( - attn_module, ChunkedLocalAttention - ): - kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size, - ) - else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - ) - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - kv_cache_spec[layer_name] = CrossAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - ) - elif attn_module.attn_type in ( - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - ): - # encoder-only attention does not need KV cache. - continue - else: - raise ValueError(f"Unknown attention type: {attn_module.attn_type}") - - mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) - if len(mamba_layers) > 0: - if ( - self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"] - ): - raise NotImplementedError( - "Mamba with speculative decoding is not supported yet." 
- ) - mamba_block_size = self.vllm_config.cache_config.mamba_block_size - page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded - - for layer_name, mamba_module in mamba_layers.items(): - kv_cache_spec[layer_name] = MambaSpec( - shapes=mamba_module.get_state_shape(), - dtypes=mamba_module.get_state_dtype(), - block_size=mamba_block_size, - page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type, - num_speculative_blocks=( - self.speculative_config.num_speculative_tokens - if self.speculative_config - else 0 - ), - ) - ds_indexer_layers = get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache - ) - for layer_name, ds_indexer_module in ds_indexer_layers.items(): - kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() + # Skip modules that don't need KV cache (eg encoder-only attention) + if spec := attn_module.get_kv_cache_spec(self.vllm_config): + kv_cache_spec[layer_name] = spec return kv_cache_spec diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index fb63fe8d2543..9de123263755 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any import torch @@ -21,7 +22,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import has_deep_gemm +from vllm.utils.import_utils import has_deep_gemm from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts logger = init_logger(__name__) @@ -32,8 +33,8 @@ class UbatchMetadata: context: UBatchContext input_ids: torch.Tensor positions: torch.Tensor - inputs_embeds: Optional[torch.Tensor] - intermediate_tensors: Optional[IntermediateTensors] + inputs_embeds: torch.Tensor | None + intermediate_tensors: IntermediateTensors | None num_tokens: int @@ -41,7 +42,7 @@ class UbatchMetadata: class CUDAGraphMetaData: cudagraph: torch.cuda.CUDAGraph ubatch_metadata: UbatchMetadata - outputs: Optional[Any] = None + outputs: Any | None = None class SMControlContextManager: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 70f7c1d45b5f..3ed9cab42a14 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,7 +6,7 @@ import gc import os from contextlib import AbstractContextManager, nullcontext -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import torch import torch.distributed @@ -20,7 +20,10 @@ set_custom_all_reduce, ) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.distributed.parallel_state import ( + get_pp_group, + get_tp_group, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -28,7 +31,8 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask -from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling +from vllm.utils.mem_constants import GiB_bytes +from vllm.utils.mem_utils import MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from 
vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ( @@ -79,6 +83,7 @@ def __init__( # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" logger.info( "Profiling enabled. Traces will be saved to: %s", torch_profiler_trace_dir, @@ -101,7 +106,7 @@ def __init__( with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True + torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True ), ) else: @@ -131,7 +136,7 @@ def sleep(self, level: int = 1) -> None: used_bytes / GiB_bytes, ) - def wake_up(self, tags: Optional[list[str]] = None) -> None: + def wake_up(self, tags: list[str] | None = None) -> None: from vllm.device_allocator.cumem import CuMemAllocator allocator = CuMemAllocator.get_instance() @@ -311,9 +316,10 @@ def determine_available_memory(self) -> int: GiB(free_gpu_memory - unrequested_memory), ) logger.debug(profile_result) - logger.info( + logger.info_once( "Available KV cache memory: %.2f GiB", GiB(self.available_kv_cache_memory_bytes), + scope="local", ) gc.collect() @@ -325,6 +331,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" + # Init kv cache connector here, because it requires + # `kv_cache_config`. + # NOTE(Kuntai): This need to be done before `initialize_kv_cache`, + # because `initialize_kv_cache` will inject kv cache groups not + # related to kv cache connector (e.g. kv cache sharing layers). + connector_vllm_config = copy.copy(self.vllm_config) + connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) + ensure_kv_transfer_initialized(connector_vllm_config) + if self.vllm_config.model_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator @@ -442,6 +457,9 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. 
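# Illustrative sketch of the connector initialization above: the worker hands
# the KV connector a shallow copy of its config with the resolved
# kv_cache_config attached, so the connector sees the final cache layout
# without mutating the config shared by the rest of the worker. The dataclasses
# are stand-ins, not vLLM types.
import copy
from dataclasses import dataclass

@dataclass
class FakeKVCacheConfig:
    num_blocks: int

@dataclass
class FakeVllmConfig:
    model: str
    kv_cache_config: FakeKVCacheConfig | None = None

shared_config = FakeVllmConfig(model="demo")
connector_config = copy.copy(shared_config)
connector_config.kv_cache_config = FakeKVCacheConfig(num_blocks=4096)
assert shared_config.kv_cache_config is None   # the shared config stays untouched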
set_random_seed(self.model_config.seed) + def reset_mm_cache(self) -> None: + self.model_runner.reset_mm_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() @@ -452,7 +470,7 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None: intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -500,7 +518,7 @@ def execute_model( output.kv_connector_output = kv_connector_output return output - def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + def take_draft_token_ids(self) -> DraftTokenIds | None: return self.model_runner.take_draft_token_ids() def profile(self, is_start: bool = True): @@ -561,7 +579,7 @@ def _eplb_after_scale_up( self, old_ep_size: int, new_ep_size: int, - global_expert_load: Optional[torch.Tensor], + global_expert_load: torch.Tensor | None, ) -> None: from vllm.distributed.parallel_state import get_ep_group @@ -607,7 +625,7 @@ def _reconfigure_parallel_config( def _reconfigure_moe( self, old_ep_size: int, new_ep_size: int - ) -> Optional[torch.Tensor]: + ) -> torch.Tensor | None: """ Reconfigure MoE modules with provided reconfig_request @@ -726,8 +744,8 @@ def reinitialize_distributed( def save_sharded_state( self, path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, + pattern: str | None = None, + max_size: int | None = None, ) -> None: from vllm.model_executor.model_loader import ShardedStateLoader @@ -754,12 +772,15 @@ def shutdown(self) -> None: def init_worker_distributed_environment( vllm_config: VllmConfig, rank: int, - distributed_init_method: Optional[str] = None, + distributed_init_method: str | None = None, local_rank: int = -1, backend: str = "nccl", ) -> None: """Initialize the distributed environment.""" parallel_config = vllm_config.parallel_config + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + + init_batch_invariance() set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment( @@ -771,5 +792,3 @@ def init_worker_distributed_environment( parallel_config.pipeline_parallel_size, parallel_config.decode_context_parallel_size, ) - - ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 473982bebb12..db037a9fccd5 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -9,7 +9,6 @@ from contextlib import AbstractContextManager, contextmanager, nullcontext from typing import ( TYPE_CHECKING, # noqa: UP035 - Optional, ) from vllm.config import VllmConfig @@ -65,7 +64,7 @@ def maybe_wait_for_kv_save() -> None: @staticmethod def get_finished_kv_transfers( scheduler_output: "SchedulerOutput", - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + ) -> tuple[set[str] | None, set[str] | None]: if has_kv_transfer_group(): return get_kv_transfer_group().get_finished( scheduler_output.finished_req_ids @@ -95,7 +94,7 @@ def kv_connector_no_forward( @staticmethod def maybe_get_kv_connector_output( scheduler_output: "SchedulerOutput", - ) -> AbstractContextManager[Optional[KVConnectorOutput]]: + ) -> AbstractContextManager[KVConnectorOutput | None]: return ( 
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output) if has_kv_transfer_group() @@ -139,7 +138,7 @@ def _get_kv_connector_output( kv_connector.clear_connector_metadata() @staticmethod - def get_kv_connector_stats() -> Optional[KVConnectorStats]: + def get_kv_connector_stats() -> KVConnectorStats | None: if has_kv_transfer_group(): return get_kv_transfer_group().get_kv_connector_stats() return None diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index e7358c4271ce..372bc0a05673 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -5,7 +5,6 @@ """ from contextlib import contextmanager -from typing import Optional, Union import numpy as np import torch @@ -21,15 +20,13 @@ from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch -InputBatch = Union[TPUInputBatch, GPUInputBatch] +InputBatch = TPUInputBatch | GPUInputBatch logger = init_logger(__name__) # Defined as a mixin for GPUModelRunner class LoRAModelRunnerMixin: - LORA_WARMUP_RANK = 8 - def load_lora_model( self, model: nn.Module, vllm_config: VllmConfig, device: torch.device ) -> nn.Module: @@ -87,7 +84,7 @@ def set_active_loras( @contextmanager def maybe_setup_dummy_loras( - self, lora_config: Optional[LoRAConfig], remove_lora: bool = True + self, lora_config: LoRAConfig | None, remove_lora: bool = True ): if lora_config is None: yield @@ -96,7 +93,9 @@ def maybe_setup_dummy_loras( assert self.lora_manager is not None, "LoRA is not enabled" num_loras = lora_config.max_loras - + lora_warmup_rank = ( + lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8 + ) # Make dummy lora requests lora_requests: set[LoRARequest] = { LoRARequest( @@ -111,7 +110,7 @@ def maybe_setup_dummy_loras( # Add the dummy LoRAs here so _set_active_loras doesn't try to # load from disk. for lr in lora_requests: - self.lora_manager.add_dummy_lora(lr, rank=self.LORA_WARMUP_RANK) + self.lora_manager.add_dummy_lora(lr, rank=lora_warmup_rank) yield @@ -121,7 +120,10 @@ def maybe_setup_dummy_loras( @contextmanager def maybe_select_dummy_loras( - self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray + self, + lora_config: LoRAConfig | None, + num_scheduled_tokens: np.ndarray, + activate_lora: bool = True, ): if lora_config is None: yield @@ -134,7 +136,12 @@ def maybe_select_dummy_loras( # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. 
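# Illustrative sketch of the dummy LoRA mapping below: with activate_lora the
# adapter ids are assigned cyclically (worst case, every adapter active at
# once); without it the mapping is all zeros, i.e. no adapter, which is the new
# LoRA-off cudagraph case. Request and token counts are invented.
import numpy as np

num_reqs, num_loras = 5, 2
num_scheduled_tokens = np.array([3, 1, 2, 1, 1], dtype=np.int32)

active_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1   # [1 2 1 2 1]
idle_mapping = np.zeros(num_reqs, dtype=np.int32)                        # [0 0 0 0 0]
token_mapping = np.repeat(active_mapping, num_scheduled_tokens)          # per-token ids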
- prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1 + if activate_lora: + prompt_lora_mapping = ( + np.arange(num_reqs, dtype=np.int32) % num_loras + ) + 1 + else: + prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32) # Make token lora mapping token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) @@ -158,17 +165,20 @@ def maybe_select_dummy_loras( @contextmanager def maybe_dummy_run_with_lora( self, - lora_config: Optional[LoRAConfig], + lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray, + activate_lora: bool = True, remove_lora: bool = True, ): with ( self.maybe_setup_dummy_loras(lora_config, remove_lora), - self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens), + self.maybe_select_dummy_loras( + lora_config, num_scheduled_tokens, activate_lora + ), ): yield - def maybe_remove_all_loras(self, lora_config: Optional[LoRAConfig]): + def maybe_remove_all_loras(self, lora_config: LoRAConfig | None): if lora_config is None: return self.lora_manager.remove_all_adapters() diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 34fed8f96467..74e8225b2f4b 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -2,14 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Datastructures defining a TPU input batch -from typing import Optional, cast +from typing import cast import numpy as np import torch from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingType -from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values +from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.utils.collection_utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.worker.block_table import MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState @@ -27,6 +28,7 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + kernel_block_sizes: list[int], ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -35,7 +37,7 @@ def __init__( self.pin_memory = pin_memory self.vocab_size = vocab_size - self._req_ids: list[Optional[str]] = [] + self._req_ids: list[str | None] = [] self.req_id_to_index: dict[str, int] = {} # TODO(woosuk): This buffer could be too large if max_model_len is big. @@ -68,6 +70,7 @@ def __init__( pin_memory=pin_memory, device=device, block_sizes=block_sizes, + kernel_block_sizes=kernel_block_sizes, ) # Sampling-related. @@ -153,17 +156,17 @@ def __init__( # To accumulate prompt logprobs tensor chunks across prefill steps. self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} - self.logit_bias: list[Optional[dict[int, float]]] = [None] * max_num_reqs + self.logit_bias: list[dict[int, float] | None] = [None] * max_num_reqs self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, # the value is False. Since we use masked_fill_ to set -inf. 
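# Illustrative sketch of the mask convention noted above: entries are False for
# allowed tokens and True for everything else, so masked_fill_ can write -inf
# straight into the logits. Vocab size and ids are invented.
import torch

vocab_size = 8
allowed_token_ids = [1, 3, 5]
mask = torch.ones(vocab_size, dtype=torch.bool)   # True = not allowed
mask[allowed_token_ids] = False

logits = torch.zeros(vocab_size)
logits.masked_fill_(mask, float("-inf"))          # only ids 1, 3, 5 stay finite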
- self.allowed_token_ids_mask: Optional[torch.Tensor] = None - self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None + self.allowed_token_ids_mask: torch.Tensor | None = None + self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None # req_index -> bad_words_token_ids self.bad_words_token_ids: dict[int, list[list[int]]] = {} - self.req_output_token_ids: list[Optional[list[int]]] = [] + self.req_output_token_ids: list[list[int] | None] = [] @property def req_ids(self) -> list[str]: @@ -174,7 +177,7 @@ def req_ids(self) -> list[str]: def add_request( self, request: "CachedRequestState", - req_index: Optional[int] = None, + req_index: int | None = None, ) -> None: if req_index is None: req_index = self.num_reqs @@ -212,8 +215,8 @@ def add_request( sampling_params = request.sampling_params assert sampling_params is not None, "pooling requests not supported yet" if sampling_params.sampling_type == SamplingType.GREEDY: - # Avoid later division by zero. - self.temperature_cpu[req_index] = -1.0 + # Should avoid division by zero later when apply_temperature. + self.temperature_cpu[req_index] = 0.0 self.greedy_reqs.add(req_id) else: self.temperature_cpu[req_index] = sampling_params.temperature @@ -294,7 +297,7 @@ def add_request( # No LoRA self.request_lora_mapping[req_index] = 0 - def remove_request(self, req_id: str) -> Optional[int]: + def remove_request(self, req_id: str) -> int | None: """This method must always be followed by a call to condense().""" req_index = self.req_id_to_index.pop(req_id, None) @@ -578,7 +581,7 @@ def no_penalties(self) -> bool: ) @property - def max_num_logprobs(self) -> Optional[int]: + def max_num_logprobs(self) -> int | None: return max(self.num_logprobs.values()) if self.num_logprobs else None @property diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 1d53fa954a7f..18b857a64136 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch import numpy as np @@ -19,6 +19,7 @@ import vllm.envs as envs from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import MLAAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import ( @@ -32,6 +33,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.tpu import TPUModelLoader from vllm.model_executor.models.interfaces import ( @@ -51,7 +53,8 @@ from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available, prev_power_of_2 +from vllm.utils import LayerBlockType, cdiv, prev_power_of_2 +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.pallas import ( TPU_STR_DTYPE_TO_TORCH_DTYPE, PallasAttentionBackend, @@ -63,6 +66,7 @@ FullAttentionSpec, KVCacheConfig, KVCacheSpec, + 
MLAAttentionSpec, SlidingWindowSpec, ) from vllm.v1.outputs import ( @@ -137,7 +141,7 @@ def __init__( self, vllm_config: VllmConfig, device: torch.device, - original_parallel_config: Optional[ParallelConfig] = None, + original_parallel_config: ParallelConfig | None = None, ): self.vllm_config = vllm_config self.model_config = vllm_config.model_config @@ -256,6 +260,7 @@ def __init__( pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], + kernel_block_sizes=[self.cache_config.block_size], ) # Cached torch/numpy tensor @@ -367,6 +372,10 @@ def __init__( else: self.sample_from_logits_func = self.sample_from_logits + def reset_mm_cache(self) -> None: + if self.mm_budget: + self.mm_budget.reset_cache() + def _update_num_xla_graphs(self, case_str): check_comp = self.check_recompilation and not self.enforce_eager if not check_comp: @@ -561,52 +570,71 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - layers = get_layers_from_vllm_config(self.vllm_config, Attention) + layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) block_size = self.vllm_config.cache_config.block_size + cache_dtype_str = self.vllm_config.cache_config.cache_dtype + kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in layers.items(): - if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue + # Classic Attention path + if isinstance(attn_module, Attention): + if ( + kv_tgt_layer := attn_module.kv_sharing_target_layer_name + ) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue - if attn_module.attn_type == AttentionType.DECODER: - if isinstance(attn_module, ChunkedLocalAttention): - logger.warning_once( - "Using irope in Pallas is not supported yet, it " - "will fall back to global attention for long context." - ) - if attn_module.sliding_window is not None: - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - ) + if attn_module.attn_type == AttentionType.DECODER: + if isinstance(attn_module, ChunkedLocalAttention): + logger.warning_once( + "Using irope in Pallas is not supported yet, it " + "will fall back to global attention for long context." 
+ ) + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - ) - elif attn_module.attn_type in ( - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - ): - # encoder-only attention does not need KV cache. - continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") + # MLAAttention path + elif isinstance(attn_module, MLAAttention): + if layer_name in kv_cache_spec: + continue + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + cache_dtype_str=cache_dtype_str, + ) else: - raise ValueError(f"Unknown attention type: {attn_module.attn_type}") + continue return kv_cache_spec @@ -1023,7 +1051,7 @@ def _gather_mm_embeddings( def _get_model_inputs( self, input_ids: torch.Tensor, - mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]], + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None, ): if self.supports_mm_inputs: mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) @@ -1049,7 +1077,7 @@ def _get_model_inputs( def execute_model( self, scheduler_output: "SchedulerOutput", - intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors: IntermediateTensors | None = None, ) -> ModelRunnerOutput: # Update cached state self._update_states(scheduler_output) @@ -1193,7 +1221,7 @@ def concat_lists(input_lists): ), "req_ids contains None" req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) - prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None @@ -1766,6 +1794,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: block_sizes=[ kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size ], + kernel_block_sizes=[ + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + ], ) # Verify dtype compatibility between block_table_cpu and input_batch assert ( @@ -1933,12 +1964,8 @@ def prepare_structured_decoding_input( self.grammar_bitmask_cpu.zero_() self.require_structured_out_cpu.zero_() - sorted_struct_requests = sorted( - scheduler_output.structured_output_request_ids.items(), - key=lambda item: item[1], - ) cumulative_mask_idx = 0 - for req_id, _ in sorted_struct_requests: + for req_id in scheduler_output.structured_output_request_ids: if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id] @@ -2097,13 +2124,12 @@ def _tpu_set_lora( index: int, lora_a: 
torch.Tensor, lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + embeddings_tensor: torch.Tensor | None, ): # TODO: The integer index leads to a recompilation, but converting it # to a tensor doesn't seem to work anymore. This might be fixed with a # later release of torch_xla. - self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) + self._original_set_lora(index, lora_a, lora_b, embeddings_tensor) torch_xla.sync(wait=False) def _tpu_reset_lora(self, index: int): diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 861d7ae737ee..fae1f8e37b0c 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -3,10 +3,10 @@ """A TPU worker class.""" import os -from typing import Any, Callable, Optional, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar import torch -import torch.distributed import torch.nn as nn import vllm.envs as envs @@ -25,7 +25,8 @@ from vllm.platforms import current_platform from vllm.platforms.tpu import USE_TPU_INFERENCE from vllm.tasks import SupportedTask -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput @@ -182,8 +183,8 @@ def determine_available_memory(self) -> int: if isinstance(layer_spec, AttentionSpec): dtype = layer_spec.dtype - # Use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. + # Use an empty tensor instead of `None` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device) kv_caches[layer_name] = tpu_kv_cache else: @@ -257,7 +258,7 @@ def determine_available_memory(self) -> int: def execute_model( self, scheduler_output: "SchedulerOutput", - ) -> Optional[ModelRunnerOutput]: + ) -> ModelRunnerOutput | None: output = self.model_runner.execute_model(scheduler_output) # every worker's output is needed when kv_transfer_group is set up return output if self.is_driver_worker or has_kv_transfer_group() else None @@ -293,6 +294,9 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. 
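# Illustrative sketch of the prepare_structured_decoding_input change above:
# the TPU runner now walks structured_output_request_ids in the order the
# scheduler emitted them rather than re-sorting by a stored position, skipping
# ids that have already left the batch. The data is invented and the bitmask
# bookkeeping is elided.
structured_output_request_ids = ["req-a", "req-b", "req-d"]   # scheduler order
req_id_to_index = {"req-a": 0, "req-b": 2, "req-c": 3}        # req-d already finished

for req_id in structured_output_request_ids:
    if req_id not in req_id_to_index:
        continue
    batch_index = req_id_to_index[req_id]
    # the corresponding grammar bitmask row would be copied to batch_index here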
set_random_seed(self.model_config.seed) + def reset_mm_cache(self) -> None: + self.model_runner.reset_mm_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() @@ -314,7 +318,7 @@ def _init_tpu_worker_distributed_environment( self, vllm_config: VllmConfig, rank: int, - distributed_init_method: Optional[str] = None, + distributed_init_method: str | None = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index ef22977e094b..33a1921d2d98 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from typing import TypeAlias import numpy as np -from typing_extensions import TypeAlias from vllm.config import ParallelConfig diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 867ce2b93036..6edcb7848638 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -7,7 +7,7 @@ from vllm import forward_context from vllm.forward_context import ForwardContext -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream _THREAD_ID_TO_CONTEXT: dict = {} _CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index c3d16827f10e..92baf0cb7136 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch @@ -42,10 +42,10 @@ def __init__( self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache) - max_tokens_by_modality = ( - mm_registry.get_max_tokens_per_item_by_nonzero_modality( - model_config, cache=cache - ) + max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + profiler_limits=self.mm_limits, ) encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( @@ -126,6 +126,10 @@ def get_max_items( return max_items_per_prompt, max_items_per_batch + def reset_cache(self) -> None: + if self.cache is not None: + self.cache.clear_cache() + @dataclass class AttentionGroup: @@ -189,7 +193,7 @@ def sanity_check_mm_encoder_outputs( def scatter_mm_placeholders( embeds: torch.Tensor, - is_embed: Optional[torch.Tensor], + is_embed: torch.Tensor | None, ) -> torch.Tensor: """ Scatter the multimodal embeddings into a contiguous tensor that represents @@ -217,7 +221,7 @@ def scatter_mm_placeholders( def gather_mm_placeholders( placeholders: torch.Tensor, - is_embed: Optional[torch.Tensor], + is_embed: torch.Tensor | None, ) -> torch.Tensor: """ Reconstructs the embeddings from the placeholder tokens. 
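# Illustrative sketch (not the vLLM implementation) of the is_embed contract
# used by scatter_mm_placeholders / gather_mm_placeholders above: with a
# boolean mask over the placeholder span, scatter and gather reduce to boolean
# indexing, and is_embed=None means the whole span is embeddings. Shapes are
# invented.
import torch

is_embed = torch.tensor([False, True, True, False, True])
embeds = torch.randn(int(is_embed.sum()), 4)            # one row per True slot

placeholders = embeds.new_zeros(is_embed.shape[0], 4)   # scatter
placeholders[is_embed] = embeds

recovered = placeholders[is_embed]                      # gather
assert torch.equal(recovered, embeds)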
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index dc9bb3910fbc..d912589ef73a 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from __future__ import annotations
-
 import os
-from typing import Any, Callable, TypeVar, Union
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypeVar
 
 import torch
 import torch.nn as nn
@@ -12,16 +11,23 @@
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import ExecuteModelRequest
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.cache import worker_receiver_cache_from_config
 from vllm.utils import (
     enable_trace_function_call_for_thread,
-    resolve_obj_by_qualname,
     run_method,
-    update_environment_variables,
     warn_for_unimplemented_methods,
 )
+from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.system_utils import update_environment_variables
 from vllm.v1.kv_cache_interface import KVCacheSpec
-from vllm.v1.outputs import SamplerOutput
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.outputs import ModelRunnerOutput
+else:
+    SchedulerOutput = object
+    ModelRunnerOutput = object
 
 logger = init_logger(__name__)
 
@@ -103,6 +109,11 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
         """Initialize the KV cache with the given size in blocks."""
         raise NotImplementedError
 
+    def reset_mm_cache(self) -> None:
+        reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
+        if callable(reset_fn):
+            reset_fn()
+
     def get_model(self) -> nn.Module:
         raise NotImplementedError
 
@@ -114,35 +125,7 @@ def load_model(self) -> None:
         """Load model onto target device."""
         raise NotImplementedError
 
-    def execute_model(
-        self, execute_model_req: ExecuteModelRequest | None = None
-    ) -> list[SamplerOutput] | None:
-        raise NotImplementedError
-
-    def start_worker_execution_loop(self) -> None:
-        """Execute model loop in parallel worker.
-
-        You can stop the loop by executing a driver worker with an empty output.
-        See `stop_remote_worker_execution_loop` for more details.
-        """
-        with self.current_platform.inference_mode():
-            while True:
-                output = self.execute_model(execute_model_req=None)
-                if output is None:
-                    return None
-
-    def determine_num_available_blocks(self) -> tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-
-        Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
-        """
+    def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
         raise NotImplementedError
 
     def get_cache_block_size_bytes(self) -> int:
@@ -289,6 +272,28 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
             worker_class,
             extended_calls,
         )
+
+        shared_worker_lock = kwargs.pop("shared_worker_lock", None)
+        if shared_worker_lock is None:
+            msg = (
+                "Missing `shared_worker_lock` argument from executor. "
+                "This argument is needed for mm_processor_cache_type='shm'."
+            )
+
+            mm_config = self.vllm_config.model_config.multimodal_config
+            if mm_config and mm_config.mm_processor_cache_type == "shm":
+                raise ValueError(msg)
+            else:
+                logger.warning_once(msg)
+
+            self.mm_receiver_cache = None
+        else:
+            self.mm_receiver_cache = worker_receiver_cache_from_config(
+                self.vllm_config,
+                MULTIMODAL_REGISTRY,
+                shared_worker_lock,
+            )
+
         with set_current_vllm_config(self.vllm_config):
             # To make vLLM config available during worker initialization
             self.worker = worker_class(**kwargs)
@@ -304,7 +309,7 @@ def init_device(self):
             # To make vLLM config available during device initialization
             self.worker.init_device()  # type: ignore
 
-    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
+    def execute_method(self, method: str | bytes, *args, **kwargs):
         try:
             # method resolution order:
             # if a method is defined in this class, it will be called directly.
@@ -323,5 +328,34 @@ def execute_method(self, method: str | bytes, *args, **kwargs):
             logger.exception(msg)
             raise e
 
-    def __getattr__(self, attr):
+    def __getattr__(self, attr: str):
         return getattr(self.worker, attr)
+
+    def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
+        mm_cache = self.mm_receiver_cache
+        if mm_cache is None:
+            return
+
+        for req_data in scheduler_output.scheduled_new_reqs:
+            req_data.mm_features = mm_cache.get_and_update_features(
+                req_data.mm_features
+            )
+
+    def execute_model(
+        self,
+        scheduler_output: SchedulerOutput,
+        *args,
+        **kwargs,
+    ) -> ModelRunnerOutput:
+        self._apply_mm_cache(scheduler_output)
+
+        assert self.worker is not None
+        return self.worker.execute_model(scheduler_output, *args, **kwargs)
+
+    def reset_mm_cache(self) -> None:
+        mm_receiver_cache = self.mm_receiver_cache
+        if mm_receiver_cache is not None:
+            mm_receiver_cache.clear_cache()
+
+        assert self.worker is not None
+        self.worker.reset_mm_cache()
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index a1e54628d9ed..31fa3f3bd6ac 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -39,6 +39,7 @@ def __init__(
         #   VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
         if envs.VLLM_TORCH_PROFILER_DIR:
             torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
             logger.info(
                 "Profiling enabled. Traces will be saved to: %s",
                 torch_profiler_trace_dir,
@@ -61,7 +62,7 @@ def __init__(
                 with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                 with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    torch_profiler_trace_dir, use_gzip=True
+                    torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
                 ),
             )
         else:
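# A short sketch of the per-rank trace naming shown above: giving each worker
# its own `worker_name` keeps the TensorBoard trace files from different ranks
# distinct inside a shared trace directory. The instance id, rank, and trace
# directory below are made-up values for illustration.
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler

instance_id, rank = "vllm-instance-1234", 0
trace_dir = tempfile.mkdtemp(prefix="vllm_profile_")

prof = profile(
    activities=[ProfilerActivity.CPU],
    on_trace_ready=tensorboard_trace_handler(
        trace_dir,
        worker_name=f"{instance_id}-rank-{rank}",
        use_gzip=True,
    ),
)
with prof:
    # Tiny workload so the profiler has something to record.
    torch.randn(16, 16) @ torch.randn(16, 16)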